From b517a1ad7857d982cb2cb42506aadc342f27f0b9 Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Wed, 17 May 2023 11:34:43 +0200
Subject: [PATCH] changing to hparams from chappuis

---
 train_rsvqa.py | 39 +++++++++++++++++++++++++--------------
 1 file changed, 25 insertions(+), 14 deletions(-)

diff --git a/train_rsvqa.py b/train_rsvqa.py
index 0734109..85b34b5 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -234,18 +234,22 @@ def overwrite_vision_weights(model, vision_checkpoint):
 
 
 def mutan(fusion_in: int, fusion_out: int):
-    opt = {
-        'dim_hv': fusion_in,
-        'dim_hq': fusion_in,
-        'dim_mm': fusion_out,
-        'dropout_hv': 0.1,
-        'dropout_hq': 0.1,
-        'R': 10
+    opt = {  # values copied from chappuis' code
+        'dim_v': 1200,  # 2048,
+        'dim_q': 1200,  # 2400,
+        'dim_hv': 360,
+        'dim_hq': 360,
+        'dim_mm': 360,
+        'R': 10,
+        'dropout_v': 0.5,
+        'dropout_q': 0.5,
+        'activation_v': 'tanh',
+        'activation_q': 'tanh',
+        'dropout_hv': 0,
+        'dropout_hq': 0
     }
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    f = MutanFusion(visual_embedding=False,
-                    question_embedding=False,
-                    opt=opt).to(device)
+    f = MutanFusion(opt=opt).to(device)
     return f
 
 
@@ -274,8 +278,11 @@ def main(
     img_size = 120
     channels = 10
 
-    fusion_in = 1200
-    fusion_out = 360
+    fusion_in = 1200  # Chappuis  FUSION_IN
+    # !! QUESTION_OUT is not usable in ConfigILMv0.3.0
+    fusion_out = 360  # Chappuis MUTAN_OUT
+
+
     model_config = ILMConfiguration(
         timm_model_name=vision_model,
         hf_model_name=text_model,
@@ -283,11 +290,14 @@ def main(
         image_size=img_size,
         channels=channels,
         network_type=ILMType.VQA_CLASSIFICATION,
-        visual_features_out=2048,
+        visual_features_out=2048,  # Chappuis VISUAL_OUT
         fusion_in=fusion_in,
         fusion_out=fusion_out,
         fusion_method=mutan(fusion_in=fusion_in, fusion_out=fusion_out),
-        fusion_hidden=256
+        fusion_hidden=256,  # Chappuis FUSION_HDIDDEN,
+        v_dropout_rate=0.5,  # Chappuis DROPOUT_V
+        t_dropout_rate=0.5,  # Chappuis DROPOUT_Q
+        fusion_dropout_rate=0.5  # Chappuis DROPOUT_F
     )
 
     # Key is available by wandb, project name can be chosen at will
@@ -298,6 +308,7 @@ def main(
         tags += ["Test Run"]
     if vision_checkpoint is not None:
         tags += ["Vision Pretraining"]
+    tags += ["Chappuis HParams"]
     wandb_logger = WandbLogger(project=f"LiT4RSVQA",
                                log_model=True,
                                tags=tags,  # keyword arg directly to wandb.init()
-- 
GitLab