diff --git a/train_rsvqa.py b/train_rsvqa.py index 6d1e448c71a2478859a9d5206fc9bb367cc4ccfb..0734109d3b1db9a109a34790d62bd3d1138bdfab 100644 --- a/train_rsvqa.py +++ b/train_rsvqa.py @@ -240,7 +240,7 @@ def mutan(fusion_in: int, fusion_out: int): 'dim_mm': fusion_out, 'dropout_hv': 0.1, 'dropout_hq': 0.1, - 'R': 1 + 'R': 10 } device = torch.device("cuda" if torch.cuda.is_available() else "cpu") f = MutanFusion(visual_embedding=False,