From ec5e762c42c0f4a41147ad5c68c38166f350ccbb Mon Sep 17 00:00:00 2001
From: leonard <l.hackel@tu-berlin.de>
Date: Thu, 20 Jun 2024 17:28:31 +0200
Subject: [PATCH] fixing warmup and some messages

---
 ben_publication/BENv2ImageClassifier.py | 2 +-
 scripts/improve_hf_training.py          | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/ben_publication/BENv2ImageClassifier.py b/ben_publication/BENv2ImageClassifier.py
index 0739230..77736ad 100644
--- a/ben_publication/BENv2ImageClassifier.py
+++ b/ben_publication/BENv2ImageClassifier.py
@@ -35,7 +35,7 @@ class BENv2ImageEncoder(pl.LightningModule, PyTorchModelHubMixin):
     ):
         super().__init__()
         self.lr = lr
-        self.warmup = warmup
+        self.warmup = None if warmup < 0 else warmup
         self.config = config
         assert config.network_type == ILMType.IMAGE_CLASSIFICATION
         assert config.classes == 19
diff --git a/scripts/improve_hf_training.py b/scripts/improve_hf_training.py
index 2e55d68..612ae64 100644
--- a/scripts/improve_hf_training.py
+++ b/scripts/improve_hf_training.py
@@ -188,8 +188,10 @@ def main(
             print(f"Pushing to {push_path}")
             model.push_to_hub(push_path, commit_message=f"Upload {model_name}")
             print("=== Done ===")
+        else:
+            print("=== Skipping upload to Huggingface Hub because no entity was provided ===")
     else:
-        print("=== Skipping upload to Huggingface Hub ===")
+        print("=== Skipping upload to Huggingface Hub because the new model did not improve the compare metric ===")
 
     print("=== Training finished ===")
 
-- 
GitLab