diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py index 04975c29197c85ef01735dae002a2c5788422ada..bce447f22cdfc3eee3f9889a8db07a4643fb4eb6 100644 --- a/train_lit4rsvqa.py +++ b/train_lit4rsvqa.py @@ -269,6 +269,8 @@ def main( tags = ["Training", vision_model, text_model] if test_run: tags += ["Test Run"] + if vision_checkpoint is not None: + tags += ["Vision Pretraining"] wandb_logger = WandbLogger(project=f"LiT4RSVQA", log_model=True, tags=tags, # keyword arg directly to wandb.init() diff --git a/train_rsvqa.py b/train_rsvqa.py index 48e92c1a4f4e72f24e84d681b86b7b9e868b75a6..e34232fdcc0cc7213901a353d98d5ad53ebbd052 100644 --- a/train_rsvqa.py +++ b/train_rsvqa.py @@ -292,6 +292,8 @@ def main( tags = ["Training", vision_model, text_model] if test_run: tags += ["Test Run"] + if vision_checkpoint is not None: + tags += ["Vision Pretraining"] wandb_logger = WandbLogger(project=f"LiT4RSVQA", log_model=True, tags=tags, # keyword arg directly to wandb.init()