diff --git a/scripts/improve_hf_training.py b/scripts/improve_hf_training.py index 49844b439ed820cb4ec5d114628b9385315d7f39..2e55d682e64895670d46ce6ef40e0209580c0449 100644 --- a/scripts/improve_hf_training.py +++ b/scripts/improve_hf_training.py @@ -143,7 +143,12 @@ def main( torch.set_float32_matmul_precision("medium") # load the model from Huggingface Hub and evaluate to get a baseline - compare_results = download_and_evaluate_model(model_name=model_name, limit_test_batches=5 if test_run else None) + compare_results = download_and_evaluate_model( + model_name=model_name, + limit_test_batches=5 if test_run else None, + batch_size=bs, + num_workers_dataloader=workers, + ) compare_metric = compare_results["AveragePrecision"]["macro"] # train the model with the given hyperparameters based on the config of the downloaded model