From 7f06e4c5b4b5a181c602e1cef045fe0af97b1330 Mon Sep 17 00:00:00 2001 From: leonard <l.hackel@tu-berlin.de> Date: Thu, 20 Jun 2024 17:21:02 +0200 Subject: [PATCH] adding workers and bs to compare call --- scripts/improve_hf_training.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/improve_hf_training.py b/scripts/improve_hf_training.py index 49844b4..2e55d68 100644 --- a/scripts/improve_hf_training.py +++ b/scripts/improve_hf_training.py @@ -143,7 +143,12 @@ def main( torch.set_float32_matmul_precision("medium") # load the model from Huggingface Hub and evaluate to get a baseline - compare_results = download_and_evaluate_model(model_name=model_name, limit_test_batches=5 if test_run else None) + compare_results = download_and_evaluate_model( + model_name=model_name, + limit_test_batches=5 if test_run else None, + batch_size=bs, + num_workers_dataloader=workers, + ) compare_metric = compare_results["AveragePrecision"]["macro"] # train the model with the given hyperparameters based on the config of the downloaded model -- GitLab