From b517a1ad7857d982cb2cb42506aadc342f27f0b9 Mon Sep 17 00:00:00 2001 From: Leonard Hackel <l.hackel@tu-berlin.de> Date: Wed, 17 May 2023 11:34:43 +0200 Subject: [PATCH] changing to hparams from chappuis --- train_rsvqa.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/train_rsvqa.py b/train_rsvqa.py index 0734109..85b34b5 100644 --- a/train_rsvqa.py +++ b/train_rsvqa.py @@ -234,18 +234,22 @@ def overwrite_vision_weights(model, vision_checkpoint): def mutan(fusion_in: int, fusion_out: int): - opt = { - 'dim_hv': fusion_in, - 'dim_hq': fusion_in, - 'dim_mm': fusion_out, - 'dropout_hv': 0.1, - 'dropout_hq': 0.1, - 'R': 10 + opt = { # values copied from chappuis' code + 'dim_v': 1200, # 2048, + 'dim_q': 1200, # 2400, + 'dim_hv': 360, + 'dim_hq': 360, + 'dim_mm': 360, + 'R': 10, + 'dropout_v': 0.5, + 'dropout_q': 0.5, + 'activation_v': 'tanh', + 'activation_q': 'tanh', + 'dropout_hv': 0, + 'dropout_hq': 0 } device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - f = MutanFusion(visual_embedding=False, - question_embedding=False, - opt=opt).to(device) + f = MutanFusion(opt=opt).to(device) return f @@ -274,8 +278,11 @@ def main( img_size = 120 channels = 10 - fusion_in = 1200 - fusion_out = 360 + fusion_in = 1200 # Chappuis FUSION_IN + # !! 
QUESTION_OUT is not usable in ConfigILMv0.3.0 + fusion_out = 360 # Chappuis MUTAN_OUT + + model_config = ILMConfiguration( timm_model_name=vision_model, hf_model_name=text_model, @@ -283,11 +290,14 @@ def main( image_size=img_size, channels=channels, network_type=ILMType.VQA_CLASSIFICATION, - visual_features_out=2048, + visual_features_out=2048, # Chappuis VISUAL_OUT fusion_in=fusion_in, fusion_out=fusion_out, fusion_method=mutan(fusion_in=fusion_in, fusion_out=fusion_out), - fusion_hidden=256 + fusion_hidden=256, # Chappuis FUSION_HIDDEN, + v_dropout_rate=0.5, # Chappuis DROPOUT_V + t_dropout_rate=0.5, # Chappuis DROPOUT_Q + fusion_dropout_rate=0.5 # Chappuis DROPOUT_F ) # Key is available by wandb, project name can be chosen at will @@ -298,6 +308,7 @@ def main( tags += ["Test Run"] if vision_checkpoint is not None: tags += ["Vision Pretraining"] + tags += ["Chappuis HParams"] wandb_logger = WandbLogger(project=f"LiT4RSVQA", log_model=True, tags=tags, # keyword arg directly to wandb.init() -- GitLab