diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index 09ef31ba023535f2273cfd324e296244135ceec8..bd24a695dd265853b97313c297df146b37da9bc2 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -242,9 +242,9 @@ def overwrite_vision_weights(model, vision_checkpoint):
 def main(
     vision_model: str = "mobilevit_s",
     text_model: str = "prajjwal1/bert-tiny",
-    lr: float = 1e-3,
+    lr: float = 5e-4,
     epochs: int = 10,
-    batch_size: int = 32,
+    batch_size: int = 512,
     seed: int = 42,
     data_dir: str = None,
     test_run: bool = False,
@@ -284,13 +284,13 @@ def main(
         tags=tags,  # keyword arg directly to wandb.init()
     )
 
-    monitor = "val/f1"
-    monitor_str = "F1_score"
+    monitor = "val/Accuracy (Average)"
+    monitor_str = "AA"
     # checkpointing
     checkpoint_callback = ModelCheckpoint(
         monitor="val/f1",
         dirpath="./checkpoints",
-        filename=f"{wandb_logger.experiment.name}-{vision_model}-{text_model}-seed=" +
+        filename=f"{wandb_logger.experiment.name}-seed=" +
                  str(seed) + "-epoch={epoch:03d}-" +
                  f"{monitor_str}" + "={" + f"{monitor}" + ":.3f}",
         auto_insert_metric_name=False,
@@ -330,6 +330,20 @@ def main(
         tokenizer=hf_tokenizer
     )
 
+    wandb_logger.log_hyperparams(
+        {
+            "Vision Model": vision_model,
+            "Text Model": text_model,
+            "Learning Rate": lr,
+            "Epochs": epochs,
+            "Batch Size": batch_size,
+            "Seed": seed,
+            "# Workers": num_workers_dataloader,
+            "Vision Checkpoint": vision_checkpoint,
+            "GPU": torch.cuda.get_device_name()
+        }
+    )
+
     trainer.fit(model=model, datamodule=dm)
 
     trainer.test(model=model, datamodule=dm, ckpt_path="best")
diff --git a/train_rsvqa.py b/train_rsvqa.py
index 82b18c866727f84aa4630dbc3f5d76c04d75d552..05e2907de206024f5f121e6f144aec46c072c321 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -252,9 +252,9 @@ def mutan(fusion_in: int, fusion_out: int):
 def main(
     vision_model: str = "resnet152",
     text_model: str = "bert-base-uncased",
-    lr: float = 1e-3,
+    lr: float = 5e-4,
     epochs: int = 10,
-    batch_size: int = 32,
+    batch_size: int = 512,
     seed: int = 42,
     data_dir: str = None,
     test_run: bool = False,
@@ -301,13 +301,13 @@ def main(
         tags=tags,  # keyword arg directly to wandb.init()
     )
 
-    monitor = "val/f1"
-    monitor_str = "F1_score"
+    monitor = "val/Accuracy (Average)"
+    monitor_str = "AA"
     # checkpointing
     checkpoint_callback = ModelCheckpoint(
         monitor="val/f1",
         dirpath="./checkpoints",
-        filename=f"{wandb_logger.experiment.name}-{vision_model}-{text_model}-seed=" +
+        filename=f"{wandb_logger.experiment.name}-seed=" +
                  str(seed) + "-epoch={epoch:03d}-" +
                  f"{monitor_str}" + "={" + f"{monitor}" + ":.3f}",
         auto_insert_metric_name=False,
@@ -344,6 +344,20 @@ def main(
         tokenizer=hf_tokenizer
     )
 
+    wandb_logger.log_hyperparams(
+        {
+            "Vision Model": vision_model,
+            "Text Model": text_model,
+            "Learning Rate": lr,
+            "Epochs": epochs,
+            "Batch Size": batch_size,
+            "Seed": seed,
+            "# Workers": num_workers_dataloader,
+            "Vision Checkpoint": vision_checkpoint,
+            "GPU": torch.cuda.get_device_name()
+        }
+    )
+
     trainer.fit(model=model, datamodule=dm)
 
     trainer.test(model=model, datamodule=dm, ckpt_path="best")
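
For reference, a minimal sketch of what the new checkpoint filename template assembles to at runtime. This is illustrative only: the run name `"driven-sweep-7"` is an assumed stand-in for whatever `wandb_logger.experiment.name` returns, and because the scripts pass `auto_insert_metric_name=False`, Lightning's `ModelCheckpoint` resolves the `{epoch:03d}` and `{val/Accuracy (Average):.3f}` placeholders from the logged metrics at save time.

```python
# Sketch of the filename template built in both scripts after this change.
# "driven-sweep-7" is an assumed stand-in for wandb_logger.experiment.name.
monitor = "val/Accuracy (Average)"  # metric key interpolated by ModelCheckpoint
monitor_str = "AA"                  # short label embedded in the filename
seed = 42

filename = (
    "driven-sweep-7-seed="
    + str(seed)
    + "-epoch={epoch:03d}-"
    + f"{monitor_str}"
    + "={" + f"{monitor}" + ":.3f}"
)
print(filename)
# driven-sweep-7-seed=42-epoch={epoch:03d}-AA={val/Accuracy (Average):.3f}
# ModelCheckpoint later fills the {...} placeholders from the logged metrics,
# e.g. "driven-sweep-7-seed=42-epoch=007-AA=0.912".
```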