From 24bf97f5b94c23de532adf3c06487769fb6dde2d Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Thu, 1 Jun 2023 14:01:59 +0200
Subject: [PATCH] preparing for release

---
 README.md      | 141 ++++++++++++++++++
 fusion.py      | 148 -------------------
 results.csv    |  41 ++++++
 train_rsvqa.py | 381 -------------------------------------------------
 4 files changed, 182 insertions(+), 529 deletions(-)
 create mode 100644 README.md
 delete mode 100644 fusion.py
 create mode 100644 results.csv
 delete mode 100644 train_rsvqa.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..6f851d7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,141 @@
+# LiT-4-RSVQA: Lightweight Transformer-based Visual Question Answering in Remote Sensing
+This repository contains the code of the paper [`LiT-4-RSVQA: Lightweight Transformer-based
+Visual Question Answering in Remote Sensing`]() presented at the IEEE International Geoscience and Remote Sensing Symposium (IGARSS) in July 2023. The work was carried out at the [Remote Sensing Image Analysis group](https://rsim.berlin/) by [Leonard Hackel](https://rsim.berlin/team/members/leonard-hackel), [Kai Norman Clasen](https://rsim.berlin/team/members/kai-norman-clasen), [Mahdyar Ravanbakhsh](https://rsim.berlin/team/alumni) and [Begüm Demir](https://rsim.berlin/team/members/begum-demir).
+
+If you use this code, please cite our paper given below:
+
+> L. Hackel, K. N. Clasen, M. Ravanbakhsh and B. Demir, "LiT-4-RSVQA: Lightweight Transformer-based Visual Question Answering in Remote Sensing", IEEE International Geoscience and Remote Sensing Symposium (IGARSS), Pasadena, California, 2023.
+
+```bibtex
+@inproceedings{Hackel2023,
+    author = "Hackel, Leonard and Clasen, Kai Norman and Ravanbakhsh, Mahdyar and Demir, Begüm",
+    booktitle = "IEEE International Geoscience and Remote Sensing Symposium (IGARSS)",
+    address = "Pasadena, California",
+    doi = "",
+    month = "",
+    title = "{LiT-4-RSVQA}: Lightweight Transformer-based Visual Question Answering in Remote Sensing",
+    url = "",
+    year = "2023"
+}
+```
+
+## Introduction
+Visual question answering (VQA) methods in remote sensing (RS) aim to answer natural language questions with respect to an RS image.
+Most of the existing methods require a large amount of computational resources, which limits their application in operational scenarios in RS.
+To address this issue, in this paper we present an effective lightweight transformer-based VQA in RS (LiT-4-RSVQA) architecture for efficient and accurate VQA in RS.
+Our architecture consists of: 
+i) a lightweight text encoder module; 
+ii) a lightweight image encoder module; 
+iii) a fusion module; 
+and iv) a classification module.
+The experimental results obtained on a VQA benchmark dataset demonstrate that our proposed LiT-4-RSVQA architecture provides accurate VQA results while significantly reducing the computational requirements on the executing hardware.
+
+## Dataset
+Two datasets were used for training:
+1. Pre-training: [BigEarthNet](https://bigearth.net/) (only Sentinel-2)
+2. Training: [RSVQAxBEN](https://zenodo.org/record/5084904) extended with all available Sentinel-2 10m and 20m bands included in [BigEarthNet](https://bigearth.net/) 
+
+## Prerequisites
+The code in this repository uses the requirements specified in `conda_env.yml`. To install the requirements, call `conda env create -f conda_env.yml`.
+
+The code was tested on Ubuntu 20.04 with NVIDIA A100 GPUs (driver version 515.105.01, CUDA version 11.7).
+
+Create an account at [Weights & Biases](https://wandb.ai/) and save the API key in an environment variable named `WANDB_API_KEY` (e.g. `export WANDB_API_KEY=abcdefg` in the executing shell) so that results can be logged.
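+
+Putting these steps together, the setup could look like the following sketch (the environment name and the API key are placeholders; use the name defined in `conda_env.yml` and your own key):
+```
+conda env create -f conda_env.yml
+conda activate <ENV_NAME>            # name as defined in conda_env.yml
+export WANDB_API_KEY=<YOUR_API_KEY>  # enables logging to Weights & Biases
+```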
+
+## Pre-Training
+To pretrain the image encoder, call
+```
+python pretrain_lit4rsvqa.py --batch-size 512 --epochs 100 --model-name <MODEL_TO_TRAIN> --data-dir <path/to/BigEarthNet/Encoder/BigEarthNetEncoded.lmdb>
+```
+Optionally the following parameters can be set:  
+`--lr FLOAT` the learning rate  
+`--num-workers-dataloader INT` number of workers used in pytorch dataloader  
+`--seed INT` seed for initialization of RNG  
+`--test-run` or `--no-test-run` (default) only use a limited number of batches to check that the different steps of training work
+
+See also: `python pretrain_lit4rsvqa.py --help`
+
+The training will validate every 5th epoch and save a checkpoint with the best F1 score.
+After training/validation, the best checkpoint will be tested.
+Training progress and results are saved locally and on [Weights & Biases](https://wandb.ai/).
+
+For the publication, the seed 42 was used.
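+
+As an illustration, a pre-training run that fixes the publication seed could be called as follows (the model name and data path are placeholders, and the worker count is only an example value, not a setting from the paper):
+```
+python pretrain_lit4rsvqa.py --batch-size 512 --epochs 100 \
+    --model-name <MODEL_TO_TRAIN> \
+    --data-dir <path/to/BigEarthNet/Encoder/BigEarthNetEncoded.lmdb> \
+    --seed 42 --num-workers-dataloader 8
+```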
+
+## Training and Evaluation
+
+To train the VQA model end-to-end, call
+```
+python train_lit4rsvqa.py --vision-model <VISION_MODEL_TO_TRAIN> --data-dir <path/to/TRAININGFILES>
+```
+or, to start from a pre-trained vision encoder,
+```
+python train_lit4rsvqa.py --vision-model <VISION_MODEL_TO_TRAIN> --vision-checkpoint <path/to/vision/weights> --data-dir <path/to/TRAININGFILES>
+```
+where the weights are either plain PyTorch weights or a PyTorch Lightning checkpoint containing the weights.
+The `TRAININGFILES` should be set up as follows:
+```
+<TRAININGFILES>
+├── BigEarthNetEncoded.lmdb
+│   ├── data.mdb
+│   └── lock.mdb
+└── VQA_RSVQAxBEN
+    ├── RSVQAxBEN_QA_test.json
+    ├── RSVQAxBEN_QA_test_subset.json
+    ├── RSVQAxBEN_QA_train.json
+    ├── RSVQAxBEN_QA_train_subset.json
+    ├── RSVQAxBEN_QA_val.json
+    └── RSVQAxBEN_QA_val_subset.json
+```
+
+Optionally the following parameters can be set:  
+`--lr FLOAT` the learning rate  
+`--num-workers-dataloader INT` number of workers used in pytorch dataloader  
+`--seed INT` seed for initialization of RNG  
+`--test-run` or `--no-test-run` (default) only use a limited number of batches to check that the different steps of training work  
+`--text-model STR` the text encoder model (HuggingFace model name)  
+`--epochs INT` number of epochs (this automatically adjusts the LR schedule)  
+`--batch-size INT` the batch size  
+`--matmul-precision STR` precision used in [torch matmul](https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html)  
+
+See also: `python train_lit4rsvqa.py --help`
+
+For the publication, the seeds 42, 1337 and 2023 were used.
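+
+For example, a full training run with a pre-trained vision encoder and one of the publication seeds could look like this (model name and paths are placeholders; the matmul precision value is only an illustrative choice):
+```
+python train_lit4rsvqa.py --vision-model <VISION_MODEL_TO_TRAIN> \
+    --vision-checkpoint <path/to/vision/weights> \
+    --data-dir <path/to/TRAININGFILES> \
+    --seed 42 --matmul-precision medium
+```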
+
+## Authors
+**Leonard Hackel**
+https://rsim.berlin/team/members/leonard-hackel
+
+**Kai Norman Clasen**
+https://rsim.berlin/team/members/kai-norman-clasen 
+
+**Mahdyar Ravanbakhsh**
+https://rsim.berlin/team/alumni
+
+**Begüm Demir**
+https://rsim.berlin/team/members/begum-demir
+
+For questions, requests and concerns, please contact [Leonard Hackel via e-mail](mailto:l.hackel@tu-berlin.de).
+
+## License
+The code in this repository is licensed under the **MIT License**:
+```
+MIT License
+
+Copyright (c) 2023 Leonard Hackel
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+```
+
diff --git a/fusion.py b/fusion.py
deleted file mode 100644
index f215eeb..0000000
--- a/fusion.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# from https://github.com/Cadene/vqa.pytorch/blob/master/vqa/models/fusion.py
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-
-
-class AbstractFusion(nn.Module):
-
-    def __init__(self, opt={}):
-        super(AbstractFusion, self).__init__()
-        self.opt = opt
-
-    def forward(self, input_v, input_q):
-        raise NotImplementedError
-
-
-class MLBFusion(AbstractFusion):
-
-    def __init__(self, opt):
-        super(MLBFusion, self).__init__(opt)
-        # Modules
-        if 'dim_v' in self.opt:
-            self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_h'])
-        else:
-            print('Warning fusion.py: no visual embedding before fusion')
-
-        if 'dim_q' in self.opt:
-            self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_h'])
-        else:
-            print('Warning fusion.py: no question embedding before fusion')
-
-    def forward(self, input_v, input_q):
-        # visual (cnn features)
-        if 'dim_v' in self.opt:
-            x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
-            x_v = self.linear_v(x_v)
-            if 'activation_v' in self.opt:
-                x_v = getattr(F, self.opt['activation_v'])(x_v)
-        else:
-            x_v = input_v
-        # question (rnn features)
-        if 'dim_q' in self.opt:
-            x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
-            x_q = self.linear_q(x_q)
-            if 'activation_q' in self.opt:
-                x_q = getattr(F, self.opt['activation_q'])(x_q)
-        else:
-            x_q = input_q
-        # hadamard product
-        x_mm = torch.mul(x_q, x_v)
-        return x_mm
-
-
-class MutanFusion(AbstractFusion):
-
-    def __init__(self, opt, visual_embedding=True, question_embedding=True):
-        super(MutanFusion, self).__init__(opt)
-        self.visual_embedding = visual_embedding
-        self.question_embedding = question_embedding
-        # Modules
-        if self.visual_embedding:
-            self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_hv'])
-        else:
-            print('Warning fusion.py: no visual embedding before fusion')
-
-        if self.question_embedding:
-            self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_hq'])
-        else:
-            print('Warning fusion.py: no question embedding before fusion')
-
-        self.list_linear_hv = nn.ModuleList([
-            nn.Linear(self.opt['dim_hv'], self.opt['dim_mm'])
-            for i in range(self.opt['R'])])
-
-        self.list_linear_hq = nn.ModuleList([
-            nn.Linear(self.opt['dim_hq'], self.opt['dim_mm'])
-            for i in range(self.opt['R'])])
-
-    def forward(self, input_v, input_q):
-        if input_v.dim() != input_q.dim() and input_v.dim() != 2:
-            raise ValueError
-        batch_size = input_v.size(0)
-
-        if self.visual_embedding:
-            x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
-            x_v = self.linear_v(x_v)
-            if 'activation_v' in self.opt:
-                x_v = getattr(F, self.opt['activation_v'])(x_v)
-        else:
-            x_v = input_v
-
-        if self.question_embedding:
-            x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
-            x_q = self.linear_q(x_q)
-            if 'activation_q' in self.opt:
-                x_q = getattr(F, self.opt['activation_q'])(x_q)
-        else:
-            x_q = input_q
-
-        x_mm = []
-        for i in range(self.opt['R']):
-
-            x_hv = F.dropout(x_v, p=self.opt['dropout_hv'], training=self.training)
-            x_hv = self.list_linear_hv[i](x_hv)
-            if 'activation_hv' in self.opt:
-                x_hv = getattr(F, self.opt['activation_hv'])(x_hv)
-
-            x_hq = F.dropout(x_q, p=self.opt['dropout_hq'], training=self.training)
-            x_hq = self.list_linear_hq[i](x_hq)
-            if 'activation_hq' in self.opt:
-                x_hq = getattr(F, self.opt['activation_hq'])(x_hq)
-
-            x_mm.append(torch.mul(x_hq, x_hv))
-
-        x_mm = torch.stack(x_mm, dim=1)
-        x_mm = x_mm.sum(1).view(batch_size, self.opt['dim_mm'])
-
-        if 'activation_mm' in self.opt:
-            x_mm = getattr(F, self.opt['activation_mm'])(x_mm)
-
-        return x_mm
-
-
-class MutanFusion2d(MutanFusion):
-
-    def __init__(self, opt, visual_embedding=True, question_embedding=True):
-        super(MutanFusion2d, self).__init__(opt,
-                                            visual_embedding,
-                                            question_embedding)
-
-    def forward(self, input_v, input_q):
-        if input_v.dim() != input_q.dim() and input_v.dim() != 3:
-            raise ValueError
-        batch_size = input_v.size(0)
-        weight_height = input_v.size(1)
-        dim_hv = input_v.size(2)
-        dim_hq = input_q.size(2)
-        if not input_v.is_contiguous():
-            input_v = input_v.contiguous()
-        if not input_q.is_contiguous():
-            input_q = input_q.contiguous()
-        x_v = input_v.view(batch_size * weight_height, self.opt['dim_hv'])
-        x_q = input_q.view(batch_size * weight_height, self.opt['dim_hq'])
-        x_mm = super().forward(x_v, x_q)
-        x_mm = x_mm.view(batch_size, weight_height, self.opt['dim_mm'])
-        return x_mm
diff --git a/results.csv b/results.csv
new file mode 100644
index 0000000..64f2cbf
--- /dev/null
+++ b/results.csv
@@ -0,0 +1,41 @@
+,Vision Encoder,pretrained,Text Encoder,Fusion,LULC,YN,AA,OA,Params (M),FLOPs (G),seed
+Mean,,,,,36.26,87.83,62.04,79.15,92.6,4.4,
+Variance,,,,,1.1,0.3,0.7,0.43,0,0,
+,Deit3 Base,yes,Bert Tiny,Simple,0.3743,0.8809,0.6276,0.7957,92.6,4.4,42
+,Deit3 Base,yes,Bert Tiny,Simple,0.3611,0.8789,0.62,0.7918,92.6,4.4,1337
+,Deit3 Base,yes,Bert Tiny,Simple,0.3524,0.875,0.6137,0.7871,92.6,4.4,2023
+Mean,,,,,29.53,85.21,57.36,75.84,92.6,4.4,
+Variance,,,,,3.1,1.39,2.24,1.67,0,0,
+,Deit3 Base,no,Bert Tiny,Simple,0.3265,0.8652,0.5958,0.7745,92.6,4.4,42
+,Deit3 Base,no,Bert Tiny,Simple,0.2948,0.8535,0.5741,0.7595,92.6,4.4,1337
+,Deit3 Base,no,Bert Tiny,Simple,0.2645,0.8375,0.551,0.7411,92.6,4.4,2023
+Mean,,,,,38.37,88.65,63.51,80.19,11,0.3,
+Variance,,,,,1.17,0.39,0.64,0.39,0,0,
+,Deit Tiny,yes,Bert Tiny,Simple,0.3784,0.8904,0.6344,0.8042,11,0.3,42
+,Deit Tiny,yes,Bert Tiny,Simple,0.3755,0.8827,0.6291,0.7974,11,0.3,1337
+,Deit Tiny,yes,Bert Tiny,Simple,0.3971,0.8865,0.6418,0.8042,11,0.3,2023
+Mean,,,,,32.53,86.8,59.67,77.67,11,0.3,
+Variance,,,,,1.93,0.43,1.16,0.66,0,0,
+,Deit Tiny,no,Bert Tiny,Simple,0.3063,0.8631,0.5847,0.7694,11,0.3,42
+,Deit Tiny,no,Bert Tiny,Simple,0.3247,0.8704,0.5975,0.7786,11,0.3,1337
+,Deit Tiny,no,Bert Tiny,Simple,0.3449,0.8706,0.6078,0.7822,11,0.3,2023
+Mean,,,,,34.43,89.02,61.73,79.95,10.4,0.5,
+Variance,,,,,1.94,0.27,0.96,0.23,0,0,
+,MobileViT S,yes,Bert Tiny,Simple,0.3545,0.8928,0.6236,0.8022,10.4,0.5,42
+,MobileViT S,yes,Bert Tiny,Simple,0.322,0.8904,0.6062,0.7982,10.4,0.5,1337
+,MobileViT S,yes,Bert Tiny,Simple,0.3565,0.8875,0.622,0.7982,10.4,0.5,2023
+Mean,,,,,27.18,85.51,56.34,75.7,10.4,0.5,
+Variance,,,,,3.18,2.19,2.63,2.32,0,0,
+,MobileViT S,no,Bert Tiny,Simple,0.3067,0.8802,0.5935,0.7838,10.4,0.5,42
+,MobileViT S,no,Bert Tiny,Simple,0.2641,0.8399,0.552,0.743,10.4,0.5,1337
+,MobileViT S,no,Bert Tiny,Simple,0.2446,0.8451,0.5448,0.7441,10.4,0.5,2023
+Mean,,,,,39.5,89.72,64.61,81.27,8.1,0.6,
+Variance,,,,,2.04,0.65,1.35,0.88,0,0,
+,XciT Nano,yes,Bert Tiny,Simple,0.4185,0.9047,0.6616,0.8229,8.1,0.6,42
+,XciT Nano,yes,Bert Tiny,Simple,0.3821,0.8942,0.6381,0.8081,8.1,0.6,1337
+,XciT Nano,yes,Bert Tiny,Simple,0.3843,0.8927,0.6385,0.8072,8.1,0.6,2023
+Mean,,,,,21.45,81.97,51.71,71.79,8.1,0.6,
+Variance,,,,,4.07,4.66,4.36,4.56,0,0,
+,XciT Nano,no,Bert Tiny,Simple,0.1676,0.7662,0.4669,0.6655,8.1,0.6,42
+,XciT Nano,no,Bert Tiny,Simple,0.2403,0.8517,0.546,0.7489,8.1,0.6,1337
+,XciT Nano,no,Bert Tiny,Simple,0.2355,0.8412,0.5383,0.7393,8.1,0.6,2023
diff --git a/train_rsvqa.py b/train_rsvqa.py
deleted file mode 100644
index 85b34b5..0000000
--- a/train_rsvqa.py
+++ /dev/null
@@ -1,381 +0,0 @@
-# import packages
-import pytorch_lightning as pl
-import torch
-import torch.nn.functional as F
-from torch import optim
-from tqdm import tqdm
-
-from configilm import ConfigILM
-from configilm.ConfigILM import ILMConfiguration, ILMType
-from configilm.ConfigILM import get_hf_model as get_huggingface_model
-from configilm.extra.RSVQAxBEN_DataModule_LMDB_Encoder import RSVQAxBENDataModule
-from configilm.extra.BEN_lmdb_utils import resolve_ben_data_dir
-import typer
-import os
-from os.path import isfile
-import wandb
-from pytorch_lightning.loggers.wandb import WandbLogger
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.callbacks import EarlyStopping
-from pytorch_lightning.callbacks import LearningRateMonitor
-from sklearn.metrics import accuracy_score
-from torchmetrics.classification import MultilabelF1Score
-from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
-from fvcore.nn import parameter_count, FlopCountAnalysis
-
-
-from fusion import MutanFusion
-
-__author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
-os.environ["WANDB_START_METHOD"] = "thread"
-wandb_api_key = os.environ["WANDB_API_KEY"]
-
-
-class LitVisionEncoder(pl.LightningModule):
-    """
-    Wrapper around a pytorch module, allowing this module to be used in automatic
-    training with pytorch lightning.
-    Among other things, the wrapper allows us to do automatic training and removes the
-    need to manage data on different devices (e.g. GPU and CPU).
-    """
-
-    def __init__(
-            self,
-            config: ConfigILM.ILMConfiguration,
-            lr: float = 1e-3,
-    ):
-        super().__init__()
-        self.lr = lr
-        self.config = config
-        self.model = ConfigILM.ConfigILM(config)
-
-    def _disassemble_batch(self, batch):
-        images, questions, labels = batch
-        # transposing tensor, needed for Huggingface-Dataloader combination
-        questions = torch.tensor(
-            [x.tolist() for x in questions], device=self.device
-        ).T.int()
-        return (images, questions), labels
-
-    def training_step(self, batch, batch_idx):
-        x, y = self._disassemble_batch(batch)
-        x_hat = self.model(x)
-        loss = F.binary_cross_entropy_with_logits(x_hat, y)
-        self.log("train/loss", loss)
-        return {"loss": loss}
-
-    def configure_optimizers(self):
-        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
-
-        # these are steps if interval is set to step
-        max_intervals = int(self.trainer.max_epochs *
-                            len(self.trainer.datamodule.train_ds) /
-                            self.trainer.datamodule.batch_size)
-        warmup = 10000 if max_intervals > 10000 else 100 if max_intervals > 100 else 0
-
-        print(f"Optimizing for {max_intervals} steps with warmup for {warmup} steps")
-
-        lr_scheduler = {
-            'scheduler': LinearWarmupCosineAnnealingLR(
-                optimizer,
-                warmup_epochs=warmup,
-                max_epochs=max_intervals,
-                warmup_start_lr=self.lr / 10,
-                eta_min=self.lr / 10
-            ),
-            'name': 'learning_rate',
-            'interval': "step",
-            'frequency': 1
-        }
-        return [optimizer], [lr_scheduler]
-
-    def validation_step(self, batch, batch_idx):
-        x, y = self._disassemble_batch(batch)
-        x_hat = self.model(x)
-        loss = F.binary_cross_entropy_with_logits(x_hat, y)
-        return {"loss": loss, "outputs": x_hat, "labels": y}
-
-    def validation_epoch_end(self, outputs):
-        metrics = self.get_metrics(outputs)
-
-        self.log("val/loss", metrics["avg_loss"])
-        self.log("val/f1", metrics["avg_f1_score"])
-        self.log("val/Accuracy (LULC)", metrics["accuracy"]["LULC"])
-        self.log("val/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
-        self.log("val/Accuracy (Overall)", metrics["accuracy"]["Overall"])
-        self.log("val/Accuracy (Average)", metrics["accuracy"]["Average"])
-
-    def test_step(self, batch, batch_idx):
-        x, y = self._disassemble_batch(batch)
-        x_hat = self.model(x)
-        loss = F.binary_cross_entropy_with_logits(x_hat, y)
-        return {"loss": loss, "outputs": x_hat, "labels": y}
-
-    def test_epoch_end(self, outputs):
-        metrics = self.get_metrics(outputs)
-
-        self.log("test/loss", metrics["avg_loss"])
-        self.log("test/f1", metrics["avg_f1_score"])
-        self.log("test/Accuracy (LULC)", metrics["accuracy"]["LULC"])
-        self.log("test/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
-        self.log("test/Accuracy (Overall)", metrics["accuracy"]["Overall"])
-        self.log("test/Accuracy (Average)", metrics["accuracy"]["Average"])
-
-    def forward(self, batch):
-        # because we are a wrapper, we call the inner function manually
-        return self.model(batch)
-
-    def get_metrics(self, outputs):
-        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
-        logits = torch.cat([x["outputs"].cpu() for x in outputs], 0)
-        labels = torch.cat(
-            [x["labels"].cpu() for x in outputs], 0
-        )  # Tensor of size (#samples x classes)
-
-        selected_answers = self.trainer.datamodule.selected_answers
-
-        argmax_out = torch.argmax(logits, dim=1)
-        argmax_lbl = torch.argmax(labels, dim=1)
-
-        # get answers and predictions per type
-        yn_preds = []
-        yn_gts = []
-        lulc_preds = []
-        lulc_gts = []
-
-        for i, ans in enumerate(tqdm(argmax_lbl, desc="Counting answers")):
-            # Yes/No question
-            if selected_answers[ans] in ["yes", "no"]:
-
-                # stored for global Yes/No
-                yn_preds.append(argmax_out[i])
-                yn_gts.append(ans)
-
-            # LC question
-            else:
-                # stored for global LC
-                lulc_preds.append(argmax_out[i])
-                lulc_gts.append(ans)
-
-        acc_yn = accuracy_score(yn_gts, yn_preds)
-        acc_lulc = accuracy_score(lulc_gts, lulc_preds)
-
-        accuracy_dict = {
-            "Yes/No": acc_yn,
-            "LULC": acc_lulc,
-            "Overall": accuracy_score(
-                argmax_lbl, argmax_out
-            ),  # micro average on classes
-            "Average": (acc_yn + acc_lulc) / 2,  # macro average on types
-        }
-
-        f1_score = MultilabelF1Score(num_labels=self.config.classes, average=None).to(
-            logits.device
-        )(logits, labels)
-
-        avg_f1_score = float(
-            torch.sum(f1_score) / self.config.classes
-        )  # macro average f1 score
-
-        return {
-            "avg_loss": avg_loss,
-            "avg_f1_score": avg_f1_score,
-            "accuracy": accuracy_dict,
-        }
-
-
-def overwrite_vision_weights(model, vision_checkpoint):
-    if vision_checkpoint is None:
-        return model
-    if not isfile(vision_checkpoint):
-        print("Pretrained vision model not available, cannot load checkpoint")
-        return model
-    # load weights
-    # get model and pretrained state dicts
-    if torch.cuda.is_available():
-        pretrained_dict = torch.load(vision_checkpoint)
-    else:
-        pretrained_dict = torch.load(
-            vision_checkpoint, map_location=torch.device("cpu")
-        )
-    model_dict = model.state_dict()
-
-    # filter out unnecessary keys
-    # this allows to load lightning or pytorch model loading
-    if "pytorch-lightning_version" in pretrained_dict.keys():
-        # checkpoint is a Pytorch-Lightning Checkpoint
-        pretrained_dict = {
-            k: v
-            for k, v in pretrained_dict["state_dict"].items()
-            if k in model_dict
-        }
-    else:
-        pretrained_dict = {
-            k: v for k, v in pretrained_dict.items() if k in model_dict
-        }
-
-    # filter keys that have a size mismatch
-    mismatch_keys = [
-        x
-        for x in pretrained_dict.keys()
-        if pretrained_dict[x].shape != model_dict[x].shape
-    ]
-    for key in mismatch_keys:
-        del pretrained_dict[key]
-        print(f"Key '{key}' size mismatch, removing from loading")
-
-    # overwrite entries in the existing state dict
-    model_dict.update(pretrained_dict)
-
-    # load the new state dict
-    model.load_state_dict(model_dict)
-    print("Vision Model checkpoint loaded")
-    return model
-
-
-def mutan(fusion_in: int, fusion_out: int):
-    opt = {  # values copied from chappuis' code
-        'dim_v': 1200,  # 2048,
-        'dim_q': 1200,  # 2400,
-        'dim_hv': 360,
-        'dim_hq': 360,
-        'dim_mm': 360,
-        'R': 10,
-        'dropout_v': 0.5,
-        'dropout_q': 0.5,
-        'activation_v': 'tanh',
-        'activation_q': 'tanh',
-        'dropout_hv': 0,
-        'dropout_hq': 0
-    }
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    f = MutanFusion(opt=opt).to(device)
-    return f
-
-
-def main(
-        vision_model: str = "resnet152",
-        text_model: str = "bert-base-uncased",
-        lr: float = 5e-4,
-        epochs: int = 10,
-        batch_size: int = 512,
-        seed: int = 42,
-        data_dir: str = None,
-        test_run: bool = False,
-        num_workers_dataloader: int = 4,
-        vision_checkpoint: str = None
-):
-    if test_run:
-        max_img_index = 10 * batch_size
-        epochs = 10
-    else:
-        max_img_index = -1
-
-    pl.seed_everything(seed, workers=True)
-
-    torch.set_float32_matmul_precision("medium")
-
-    img_size = 120
-    channels = 10
-
-    fusion_in = 1200  # Chappuis  FUSION_IN
-    # !! QUESTION_OUT is not usable in ConfigILMv0.3.0
-    fusion_out = 360  # Chappuis MUTAN_OUT
-
-
-    model_config = ILMConfiguration(
-        timm_model_name=vision_model,
-        hf_model_name=text_model,
-        classes=1000,
-        image_size=img_size,
-        channels=channels,
-        network_type=ILMType.VQA_CLASSIFICATION,
-        visual_features_out=2048,  # Chappuis VISUAL_OUT
-        fusion_in=fusion_in,
-        fusion_out=fusion_out,
-        fusion_method=mutan(fusion_in=fusion_in, fusion_out=fusion_out),
-        fusion_hidden=256,  # Chappuis FUSION_HDIDDEN,
-        v_dropout_rate=0.5,  # Chappuis DROPOUT_V
-        t_dropout_rate=0.5,  # Chappuis DROPOUT_Q
-        fusion_dropout_rate=0.5  # Chappuis DROPOUT_F
-    )
-
-    # Key is available by wandb, project name can be chosen at will
-    wandb.login(key=wandb_api_key)
-
-    tags = ["Training", vision_model, text_model]
-    if test_run:
-        tags += ["Test Run"]
-    if vision_checkpoint is not None:
-        tags += ["Vision Pretraining"]
-    tags += ["Chappuis HParams"]
-    wandb_logger = WandbLogger(project=f"LiT4RSVQA",
-                               log_model=True,
-                               tags=tags,  # keyword arg directly to wandb.init()
-                               )
-
-    monitor = "val/Accuracy (Average)"
-    monitor_str = "AA"
-    # checkpointing
-    checkpoint_callback = ModelCheckpoint(
-        monitor="val/f1",
-        dirpath="./checkpoints",
-        filename=f"{wandb_logger.experiment.name}-seed=" +
-                 str(seed) + "-epoch={epoch:03d}-" + f"{monitor_str}" + "={" +
-                 f"{monitor}" + ":.3f}",
-        auto_insert_metric_name=False,
-        save_top_k=1,
-        mode="max",
-        save_last=True
-    )
-    early_stopping_callback = EarlyStopping(monitor=monitor, min_delta=0.00,
-                                            patience=25, verbose=False, mode="max")
-    lr_monitor = LearningRateMonitor(logging_interval='step')
-
-    trainer = pl.Trainer(
-        max_epochs=epochs,
-        accelerator="auto",
-        log_every_n_steps=5,
-        logger=wandb_logger,
-        check_val_every_n_epoch=2,
-        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
-    )
-
-    model = LitVisionEncoder(config=model_config, lr=lr)
-    model = overwrite_vision_weights(model, vision_checkpoint)
-
-    hf_tokenizer, _ = get_huggingface_model(
-        model_name=text_model, load_pretrained_if_available=False
-    )
-    dm = RSVQAxBENDataModule(
-        data_dir=resolve_ben_data_dir(data_dir=data_dir),
-        img_size=(channels, img_size, img_size),
-        num_workers_dataloader=num_workers_dataloader,
-        batch_size=batch_size,
-        max_img_idx=max_img_index,
-        tokenizer=hf_tokenizer
-    )
-
-    wandb_logger.log_hyperparams(
-        {
-            "Vision Model": vision_model,
-            "Text Model": text_model,
-            "Learning Rate": lr,
-            "Epochs": epochs,
-            "Batch Size": batch_size,
-            "Seed": seed,
-            "# Workers": num_workers_dataloader,
-            "Vision Checkpoint": vision_checkpoint,
-            "GPU": torch.cuda.get_device_name()
-        }
-    )
-
-    trainer.fit(model=model, datamodule=dm)
-    trainer.test(model=model, datamodule=dm, ckpt_path="best")
-
-    wandb.finish()
-    print("=== Training finished ===")
-
-
-if __name__ == "__main__":
-    typer.run(main)
-- 
GitLab