From 7d8737758d29ed99dc91cef5bee6a0f5683a92ff Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Thu, 11 May 2023 14:53:40 +0200
Subject: [PATCH] Correct the default lr and batch size, checkpoint on
 average accuracy (AA), and log hyperparameters

Lower the default learning rate to 5e-4 and raise the default batch size
to 512 in both training scripts. Select the best checkpoint by average
accuracy ("val/Accuracy (Average)", short AA) instead of F1, and pass the
monitor variable through to ModelCheckpoint so the monitored metric and
the filename label stay in sync. All run hyperparameters are now logged
to Weights & Biases; the vision and text model names are dropped from the
checkpoint filename since they are recorded there instead.
---
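Note (below the cut, not part of the commit): a minimal sketch of how the
new checkpoint filename template resolves. The run name, epoch, and metric
value below are made up for illustration; Lightning fills in the bracketed
groups at save time.

    # Stand-ins for wandb_logger.experiment.name and the CLI arguments.
    run_name, seed = "run-1", 42
    monitor, monitor_str = "val/Accuracy (Average)", "AA"

    template = (run_name + "-seed=" + str(seed) +
                "-epoch={epoch:03d}-" + monitor_str + "={" + monitor + ":.3f}")
    print(template)
    # run-1-seed=42-epoch={epoch:03d}-AA={val/Accuracy (Average):.3f}
    #
    # With auto_insert_metric_name=False, ModelCheckpoint substitutes both
    # bracketed groups from the logged metrics, so epoch 7 with AA=0.912
    # would be written as:
    #   run-1-seed=42-epoch=007-AA=0.912.ckpt
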
 train_lit4rsvqa.py | 26 ++++++++++++++++++++------
 train_rsvqa.py     | 26 ++++++++++++++++++++------
 2 files changed, 40 insertions(+), 12 deletions(-)

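One caveat with the new hyperparameter logging: torch.cuda.get_device_name()
raises when no CUDA device is visible, so a CPU-only test run would crash
while building the dict. A guarded lookup, sketched here rather than applied
in the patch, would keep such runs working:

    import torch

    # Fall back to a placeholder instead of letting the lookup raise on
    # machines without a CUDA device.
    gpu_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "none"
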
diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index 09ef31b..bd24a69 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -242,9 +242,9 @@ def overwrite_vision_weights(model, vision_checkpoint):
 def main(
         vision_model: str = "mobilevit_s",
         text_model: str = "prajjwal1/bert-tiny",
-        lr: float = 1e-3,
+        lr: float = 5e-4,
         epochs: int = 10,
-        batch_size: int = 32,
+        batch_size: int = 512,
         seed: int = 42,
         data_dir: str = None,
         test_run: bool = False,
@@ -284,13 +284,13 @@ def main(
                                tags=tags,  # keyword arg directly to wandb.init()
                                )
 
-    monitor = "val/f1"
-    monitor_str = "F1_score"
+    monitor = "val/Accuracy (Average)"
+    monitor_str = "AA"
     # checkpointing
     checkpoint_callback = ModelCheckpoint(
         monitor="val/f1",
         dirpath="./checkpoints",
-        filename=f"{wandb_logger.experiment.name}-{vision_model}-{text_model}-seed=" +
+        filename=f"{wandb_logger.experiment.name}-seed=" +
                  str(seed) + "-epoch={epoch:03d}-" + f"{monitor_str}" + "={" +
                  f"{monitor}" + ":.3f}",
         auto_insert_metric_name=False,
@@ -330,6 +330,20 @@ def main(
         tokenizer=hf_tokenizer
     )
 
+    wandb_logger.log_hyperparams(
+        {
+            "Vision Model": vision_model,
+            "Text Model": text_model,
+            "Learning Rate": lr,
+            "Epochs": epochs,
+            "Batch Size": batch_size,
+            "Seed": seed,
+            "# Workers": num_workers_dataloader,
+            "Vision Checkpoint": vision_checkpoint,
+            "GPU": torch.cuda.get_device_name()
+        }
+    )
+
     trainer.fit(model=model, datamodule=dm)
     trainer.test(model=model, datamodule=dm, ckpt_path="best")
 
diff --git a/train_rsvqa.py b/train_rsvqa.py
index 82b18c8..05e2907 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -252,9 +252,9 @@ def mutan(fusion_in: int, fusion_out: int):
 def main(
         vision_model: str = "resnet152",
         text_model: str = "bert-base-uncased",
-        lr: float = 1e-3,
+        lr: float = 5e-4,
         epochs: int = 10,
-        batch_size: int = 32,
+        batch_size: int = 512,
         seed: int = 42,
         data_dir: str = None,
         test_run: bool = False,
@@ -301,13 +301,13 @@ def main(
                                tags=tags,  # keyword arg directly to wandb.init()
                                )
 
-    monitor = "val/f1"
-    monitor_str = "F1_score"
+    monitor = "val/Accuracy (Average)"
+    monitor_str = "AA"
     # checkpointing
     checkpoint_callback = ModelCheckpoint(
         monitor="val/f1",
         dirpath="./checkpoints",
-        filename=f"{wandb_logger.experiment.name}-{vision_model}-{text_model}-seed=" +
+        filename=f"{wandb_logger.experiment.name}-seed=" +
                  str(seed) + "-epoch={epoch:03d}-" + f"{monitor_str}" + "={" +
                  f"{monitor}" + ":.3f}",
         auto_insert_metric_name=False,
@@ -344,6 +344,20 @@ def main(
         tokenizer=hf_tokenizer
     )
 
+    wandb_logger.log_hyperparams(
+        {
+            "Vision Model": vision_model,
+            "Text Model": text_model,
+            "Learning Rate": lr,
+            "Epochs": epochs,
+            "Batch Size": batch_size,
+            "Seed": seed,
+            "# Workers": num_workers_dataloader,
+            "Vision Checkpoint": vision_checkpoint,
+            "GPU": torch.cuda.get_device_name()
+        }
+    )
+
     trainer.fit(model=model, datamodule=dm)
     trainer.test(model=model, datamodule=dm, ckpt_path="best")
 
-- 
GitLab