diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py index 3076169ba85337df7a2c5e7b59a1dfceca232a62..09ef31ba023535f2273cfd324e296244135ceec8 100644 --- a/train_lit4rsvqa.py +++ b/train_lit4rsvqa.py @@ -243,7 +243,7 @@ def main( vision_model: str = "mobilevit_s", text_model: str = "prajjwal1/bert-tiny", lr: float = 1e-3, - epochs: int = 100, + epochs: int = 10, batch_size: int = 32, seed: int = 42, data_dir: str = None, @@ -307,7 +307,7 @@ def main( accelerator="auto", log_every_n_steps=5, logger=wandb_logger, - check_val_every_n_epoch=5, + check_val_every_n_epoch=2, callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor], ) diff --git a/train_rsvqa.py b/train_rsvqa.py index 1e2beabfbd048ccfeeafaea3bd6ad4b50a23a9b2..82b18c866727f84aa4630dbc3f5d76c04d75d552 100644 --- a/train_rsvqa.py +++ b/train_rsvqa.py @@ -253,7 +253,7 @@ def main( vision_model: str = "resnet152", text_model: str = "bert-base-uncased", lr: float = 1e-3, - epochs: int = 100, + epochs: int = 10, batch_size: int = 32, seed: int = 42, data_dir: str = None, @@ -324,7 +324,7 @@ def main( accelerator="auto", log_every_n_steps=5, logger=wandb_logger, - check_val_every_n_epoch=5, + check_val_every_n_epoch=2, callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor], )