From 7f06e4c5b4b5a181c602e1cef045fe0af97b1330 Mon Sep 17 00:00:00 2001
From: leonard <l.hackel@tu-berlin.de>
Date: Thu, 20 Jun 2024 17:21:02 +0200
Subject: [PATCH] adding workers and bs to compare call

---
 scripts/improve_hf_training.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/scripts/improve_hf_training.py b/scripts/improve_hf_training.py
index 49844b4..2e55d68 100644
--- a/scripts/improve_hf_training.py
+++ b/scripts/improve_hf_training.py
@@ -143,7 +143,12 @@ def main(
     torch.set_float32_matmul_precision("medium")
 
     # load the model from Huggingface Hub and evaluate to get a baseline
-    compare_results = download_and_evaluate_model(model_name=model_name, limit_test_batches=5 if test_run else None)
+    compare_results = download_and_evaluate_model(
+        model_name=model_name,
+        limit_test_batches=5 if test_run else None,
+        batch_size=bs,
+        num_workers_dataloader=workers,
+    )
     compare_metric = compare_results["AveragePrecision"]["macro"]
 
     # train the model with the given hyperparameters based on the config of the downloaded model
-- 
GitLab