From d43df488d81aa6d787b8167733565989c615b2be Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Wed, 3 May 2023 08:18:45 +0200
Subject: [PATCH] Add TransformerRSVQA code and update to configilm v0.3.0
 to make the code runnable

---
 conda_env.yml  |   2 +-
 fusion.py      | 148 +++++++++++++++++++++
 train_rsvqa.py | 350 +++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 499 insertions(+), 1 deletion(-)
 create mode 100644 fusion.py
 create mode 100644 train_rsvqa.py

diff --git a/conda_env.yml b/conda_env.yml
index 49e1f6a..b92f7df 100644
--- a/conda_env.yml
+++ b/conda_env.yml
@@ -57,7 +57,7 @@ dependencies:
   - pip:
     - attrs==23.1.0
     - colorama==0.4.6
-    - configilm==0.2.0
+    - configilm==0.3.0
     - cycler==0.11.0
     - fonttools==4.39.3
     - kiwisolver==1.4.4
diff --git a/fusion.py b/fusion.py
new file mode 100644
index 0000000..f215eeb
--- /dev/null
+++ b/fusion.py
@@ -0,0 +1,148 @@
+# from https://github.com/Cadene/vqa.pytorch/blob/master/vqa/models/fusion.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+
+class AbstractFusion(nn.Module):
+
+    def __init__(self, opt={}):
+        super(AbstractFusion, self).__init__()
+        self.opt = opt
+
+    def forward(self, input_v, input_q):
+        raise NotImplementedError
+
+
+class MLBFusion(AbstractFusion):
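+    """
+    MLB-style fusion: optionally projects the visual and question features to a
+    common hidden size ('dim_h') and combines them with an element-wise
+    (Hadamard) product.
+    """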
+
+    def __init__(self, opt):
+        super(MLBFusion, self).__init__(opt)
+        # Modules
+        if 'dim_v' in self.opt:
+            self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_h'])
+        else:
+            print('Warning fusion.py: no visual embedding before fusion')
+
+        if 'dim_q' in self.opt:
+            self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_h'])
+        else:
+            print('Warning fusion.py: no question embedding before fusion')
+
+    def forward(self, input_v, input_q):
+        # visual (cnn features)
+        if 'dim_v' in self.opt:
+            x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
+            x_v = self.linear_v(x_v)
+            if 'activation_v' in self.opt:
+                x_v = getattr(F, self.opt['activation_v'])(x_v)
+        else:
+            x_v = input_v
+        # question (rnn features)
+        if 'dim_q' in self.opt:
+            x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
+            x_q = self.linear_q(x_q)
+            if 'activation_q' in self.opt:
+                x_q = getattr(F, self.opt['activation_q'])(x_q)
+        else:
+            x_q = input_q
+        # hadamard product
+        x_mm = torch.mul(x_q, x_v)
+        return x_mm
+
+
+class MutanFusion(AbstractFusion):
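+    """
+    MUTAN-style fusion: optionally projects the visual and question features to
+    'dim_hv' / 'dim_hq', then approximates a bilinear interaction between them as
+    a sum of 'R' Hadamard products, each mapped to the fused size 'dim_mm'.
+    """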
+
+    def __init__(self, opt, visual_embedding=True, question_embedding=True):
+        super(MutanFusion, self).__init__(opt)
+        self.visual_embedding = visual_embedding
+        self.question_embedding = question_embedding
+        # Modules
+        if self.visual_embedding:
+            self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_hv'])
+        else:
+            print('Warning fusion.py: no visual embedding before fusion')
+
+        if self.question_embedding:
+            self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_hq'])
+        else:
+            print('Warning fusion.py: no question embedding before fusion')
+
+        self.list_linear_hv = nn.ModuleList([
+            nn.Linear(self.opt['dim_hv'], self.opt['dim_mm'])
+            for i in range(self.opt['R'])])
+
+        self.list_linear_hq = nn.ModuleList([
+            nn.Linear(self.opt['dim_hq'], self.opt['dim_mm'])
+            for i in range(self.opt['R'])])
+
+    def forward(self, input_v, input_q):
+        if input_v.dim() != input_q.dim() and input_v.dim() != 2:
+            raise ValueError
+        batch_size = input_v.size(0)
+
+        if self.visual_embedding:
+            x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
+            x_v = self.linear_v(x_v)
+            if 'activation_v' in self.opt:
+                x_v = getattr(F, self.opt['activation_v'])(x_v)
+        else:
+            x_v = input_v
+
+        if self.question_embedding:
+            x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
+            x_q = self.linear_q(x_q)
+            if 'activation_q' in self.opt:
+                x_q = getattr(F, self.opt['activation_q'])(x_q)
+        else:
+            x_q = input_q
+
+        x_mm = []
+        for i in range(self.opt['R']):
+
+            x_hv = F.dropout(x_v, p=self.opt['dropout_hv'], training=self.training)
+            x_hv = self.list_linear_hv[i](x_hv)
+            if 'activation_hv' in self.opt:
+                x_hv = getattr(F, self.opt['activation_hv'])(x_hv)
+
+            x_hq = F.dropout(x_q, p=self.opt['dropout_hq'], training=self.training)
+            x_hq = self.list_linear_hq[i](x_hq)
+            if 'activation_hq' in self.opt:
+                x_hq = getattr(F, self.opt['activation_hq'])(x_hq)
+
+            x_mm.append(torch.mul(x_hq, x_hv))
+
+        x_mm = torch.stack(x_mm, dim=1)
+        x_mm = x_mm.sum(1).view(batch_size, self.opt['dim_mm'])
+
+        if 'activation_mm' in self.opt:
+            x_mm = getattr(F, self.opt['activation_mm'])(x_mm)
+
+        return x_mm
+
+
+class MutanFusion2d(MutanFusion):
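+    """
+    MutanFusion applied to 3D inputs of shape (batch, regions, dim): the region
+    dimension is flattened into the batch dimension, fused, and reshaped back.
+    """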
+
+    def __init__(self, opt, visual_embedding=True, question_embedding=True):
+        super(MutanFusion2d, self).__init__(opt,
+                                            visual_embedding,
+                                            question_embedding)
+
+    def forward(self, input_v, input_q):
+        if input_v.dim() != input_q.dim() and input_v.dim() != 3:
+            raise ValueError
+        batch_size = input_v.size(0)
+        weight_height = input_v.size(1)
+        dim_hv = input_v.size(2)
+        dim_hq = input_q.size(2)
+        if not input_v.is_contiguous():
+            input_v = input_v.contiguous()
+        if not input_q.is_contiguous():
+            input_q = input_q.contiguous()
+        x_v = input_v.view(batch_size * weight_height, self.opt['dim_hv'])
+        x_q = input_q.view(batch_size * weight_height, self.opt['dim_hq'])
+        x_mm = super().forward(x_v, x_q)
+        x_mm = x_mm.view(batch_size, weight_height, self.opt['dim_mm'])
+        return x_mm
diff --git a/train_rsvqa.py b/train_rsvqa.py
new file mode 100644
index 0000000..48e92c1
--- /dev/null
+++ b/train_rsvqa.py
@@ -0,0 +1,350 @@
+# import packages
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+from torch import optim
+from tqdm import tqdm
+
+from configilm import ConfigILM
+from configilm.ConfigILM import ILMConfiguration, ILMType
+from configilm.ConfigILM import get_hf_model as get_huggingface_model
+from configilm.extra.RSVQAxBEN_DataModule_LMDB_Encoder import RSVQAxBENDataModule
+from configilm.extra.BEN_lmdb_utils import resolve_ben_data_dir
+import typer
+import os
+from os.path import isfile
+import wandb
+from pytorch_lightning.loggers.wandb import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import LearningRateMonitor
+from sklearn.metrics import accuracy_score
+from torchmetrics.classification import MultilabelF1Score
+from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
+
+from fusion import MutanFusion
+
+__author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
+os.environ["WANDB_START_METHOD"] = "thread"
+wandb_api_key = os.environ["WANDB_API_KEY"]
+
+
+class LitVisionEncoder(pl.LightningModule):
+    """
+    Wrapper around a pytorch module, allowing this module to be used in automatic
+    training with pytorch lightning.
+    Among other things, the wrapper allows us to do automatic training and removes the
+    need to manage data on different devices (e.g. GPU and CPU).
+    """
+
+    def __init__(
+            self,
+            config: ConfigILM.ILMConfiguration,
+            lr: float = 1e-3,
+    ):
+        super().__init__()
+        self.lr = lr
+        self.config = config
+        self.model = ConfigILM.ConfigILM(config)
+
+    def _disassemble_batch(self, batch):
+        images, questions, labels = batch
+        # the DataLoader collates the tokenized questions per token position, so
+        # they are re-stacked and transposed to shape (batch, sequence_length)
+        questions = torch.tensor(
+            [x.tolist() for x in questions], device=self.device
+        ).T.int()
+        return (images, questions), labels
+
+    def training_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        self.log("train/loss", loss)
+        return {"loss": loss}
+
+    def configure_optimizers(self):
+        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
+
+        # max_intervals counts optimizer steps, since the scheduler interval below is set to "step"
+        max_intervals = int(self.trainer.max_epochs *
+                            len(self.trainer.datamodule.train_ds) /
+                            self.trainer.datamodule.batch_size)
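+        # heuristic: 10k warmup steps for long schedules, 100 for medium-length
+        # ones and no warmup for very short (e.g. test) runs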
+        warmup = 10000 if max_intervals > 10000 else 100 if max_intervals > 100 else 0
+
+        print(f"Optimizing for {max_intervals} steps with warmup for {warmup} steps")
+
+        lr_scheduler = {
+            'scheduler': LinearWarmupCosineAnnealingLR(
+                optimizer,
+                warmup_epochs=warmup,
+                max_epochs=max_intervals,
+                warmup_start_lr=self.lr / 10,
+                eta_min=self.lr / 10
+            ),
+            'name': 'learning_rate',
+            'interval': "step",
+            'frequency': 1
+        }
+        return [optimizer], [lr_scheduler]
+
+    def validation_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        return {"loss": loss, "outputs": x_hat, "labels": y}
+
+    def validation_epoch_end(self, outputs):
+        metrics = self.get_metrics(outputs)
+
+        self.log("val/loss", metrics["avg_loss"])
+        self.log("val/f1", metrics["avg_f1_score"])
+        self.log("val/Accuracy (LULC)", metrics["accuracy"]["LULC"])
+        self.log("val/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
+        self.log("val/Accuracy (Overall)", metrics["accuracy"]["Overall"])
+        self.log("val/Accuracy (Average)", metrics["accuracy"]["Average"])
+
+    def test_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        return {"loss": loss, "outputs": x_hat, "labels": y}
+
+    def test_epoch_end(self, outputs):
+        metrics = self.get_metrics(outputs)
+
+        self.log("test/loss", metrics["avg_loss"])
+        self.log("test/f1", metrics["avg_f1_score"])
+        self.log("test/Accuracy (LULC)", metrics["accuracy"]["LULC"])
+        self.log("test/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
+        self.log("test/Accuracy (Overall)", metrics["accuracy"]["Overall"])
+        self.log("test/Accuracy (Average)", metrics["accuracy"]["Average"])
+
+    def forward(self, batch):
+        # because we are a wrapper, we call the wrapped model manually
+        return self.model(batch)
+
+    def get_metrics(self, outputs):
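+        # "outputs" is a list of per-batch dicts with "loss", "outputs" (logits)
+        # and "labels"; from these we compute the average loss, the macro-averaged
+        # F1 score over all answer classes and accuracies per question type
+        # (Yes/No and LULC) together with the overall and type-averaged accuracy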
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+        logits = torch.cat([x["outputs"].cpu() for x in outputs], 0)
+        labels = torch.cat(
+            [x["labels"].cpu() for x in outputs], 0
+        )  # Tensor of size (#samples x classes)
+
+        selected_answers = self.trainer.datamodule.selected_answers
+
+        argmax_out = torch.argmax(logits, dim=1)
+        argmax_lbl = torch.argmax(labels, dim=1)
+
+        # get answers and predictions per type
+        yn_preds = []
+        yn_gts = []
+        lulc_preds = []
+        lulc_gts = []
+
+        for i, ans in enumerate(tqdm(argmax_lbl, desc="Counting answers")):
+            # Yes/No question
+            if selected_answers[ans] in ["yes", "no"]:
+
+                # stored for global Yes/No
+                yn_preds.append(argmax_out[i])
+                yn_gts.append(ans)
+
+            # LC question
+            else:
+                # stored for global LC
+                lulc_preds.append(argmax_out[i])
+                lulc_gts.append(ans)
+
+        acc_yn = accuracy_score(yn_gts, yn_preds)
+        acc_lulc = accuracy_score(lulc_gts, lulc_preds)
+
+        accuracy_dict = {
+            "Yes/No": acc_yn,
+            "LULC": acc_lulc,
+            "Overall": accuracy_score(
+                argmax_lbl, argmax_out
+            ),  # micro average on classes
+            "Average": (acc_yn + acc_lulc) / 2,  # macro average on types
+        }
+
+        f1_score = MultilabelF1Score(num_labels=self.config.classes, average=None).to(
+            logits.device
+        )(logits, labels)
+
+        avg_f1_score = float(
+            torch.sum(f1_score) / self.config.classes
+        )  # macro average f1 score
+
+        return {
+            "avg_loss": avg_loss,
+            "avg_f1_score": avg_f1_score,
+            "accuracy": accuracy_dict,
+        }
+
+
+def overwrite_vision_weights(model, vision_checkpoint):
+    if vision_checkpoint is None:
+        return model
+    if not isfile(vision_checkpoint):
+        print("Pretrained vision model not available, cannot load checkpoint")
+        return model
+    # load weights
+    # get model and pretrained state dicts
+    if torch.cuda.is_available():
+        pretrained_dict = torch.load(vision_checkpoint)
+    else:
+        pretrained_dict = torch.load(
+            vision_checkpoint, map_location=torch.device("cpu")
+        )
+    model_dict = model.state_dict()
+
+    # filter out unnecessary keys
+    # this allows loading either Lightning or plain PyTorch checkpoints
+    if "pytorch-lightning_version" in pretrained_dict.keys():
+        # checkpoint is a Pytorch-Lightning Checkpoint
+        pretrained_dict = {
+            k: v
+            for k, v in pretrained_dict["state_dict"].items()
+            if k in model_dict
+        }
+    else:
+        pretrained_dict = {
+            k: v for k, v in pretrained_dict.items() if k in model_dict
+        }
+
+    # filter keys that have a size mismatch
+    mismatch_keys = [
+        x
+        for x in pretrained_dict.keys()
+        if pretrained_dict[x].shape != model_dict[x].shape
+    ]
+    for key in mismatch_keys:
+        del pretrained_dict[key]
+        print(f"Key '{key}' size mismatch, removing from loading")
+
+    # overwrite entries in the existing state dict
+    model_dict.update(pretrained_dict)
+
+    # load the new state dict
+    model.load_state_dict(model_dict)
+    print("Vision Model checkpoint loaded")
+    return model
+
+
+def mutan(fusion_in: int, fusion_out: int):
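+    # MUTAN fusion configuration: 'dim_hv'/'dim_hq' are the sizes of the incoming
+    # visual and question features, 'dim_mm' is the fused output size and 'R' is
+    # the number of rank-1 terms that are summed. The internal embeddings are
+    # disabled here, since the features are assumed to arrive already projected
+    # to `fusion_in`.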
+    opt = {
+        'dim_hv': fusion_in,
+        'dim_hq': fusion_in,
+        'dim_mm': fusion_out,
+        'dropout_hv': 0.1,
+        'dropout_hq': 0.1,
+        'R': 1
+    }
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    f = MutanFusion(visual_embedding=False,
+                    question_embedding=False,
+                    opt=opt).to(device)
+    return f
+
+
+def main(
+        vision_model: str = "resnet152",
+        text_model: str = "bert-base-uncased",
+        lr: float = 1e-3,
+        epochs: int = 100,
+        batch_size: int = 32,
+        seed: int = 42,
+        data_dir: str = None,
+        test_run: bool = False,
+        num_workers_dataloader: int = 4,
+        vision_checkpoint: str = None
+):
+    if test_run:
+        max_img_index = 10 * batch_size
+        epochs = 10
+    else:
+        max_img_index = -1
+
+    pl.seed_everything(seed, workers=True)
+
+    img_size = 120
+    channels = 10
+
+    fusion_in = 1200
+    fusion_out = 360
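+    # VQA classification model: a timm vision backbone and a Huggingface text
+    # encoder whose features are combined by the MUTAN fusion defined above and
+    # classified into 1000 answer classes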
+    model_config = ILMConfiguration(
+        timm_model_name=vision_model,
+        hf_model_name=text_model,
+        classes=1000,
+        image_size=img_size,
+        channels=channels,
+        network_type=ILMType.VQA_CLASSIFICATION,
+        visual_features_out=2048,
+        fusion_in=fusion_in,
+        fusion_out=fusion_out,
+        fusion_method=mutan(fusion_in=fusion_in, fusion_out=fusion_out),
+        fusion_hidden=256
+    )
+
+    # log in with the API key read from the environment; the project name can be chosen at will
+    wandb.login(key=wandb_api_key)
+
+    tags = ["Training", vision_model, text_model]
+    if test_run:
+        tags += ["Test Run"]
+    wandb_logger = WandbLogger(project="LiT4RSVQA",
+                               log_model=True,
+                               tags=tags,  # keyword args are passed on to wandb.init()
+                               )
+
+    monitor = "val/f1"
+    monitor_str = "F1_score"
+    # checkpointing
+    checkpoint_callback = ModelCheckpoint(
+        monitor="val/f1",
+        dirpath="./checkpoints",
+        filename=f"{wandb_logger.experiment.name}-{vision_model}-{text_model}-seed=" +
+                 str(seed) + "-epoch={epoch:03d}-" + f"{monitor_str}" + "={" +
+                 f"{monitor}" + ":.3f}",
+        auto_insert_metric_name=False,
+        save_top_k=1,
+        mode="max",
+        save_last=True
+    )
+    early_stopping_callback = EarlyStopping(monitor=monitor, min_delta=0.00,
+                                            patience=25, verbose=False, mode="max")
+    lr_monitor = LearningRateMonitor(logging_interval='step')
+
+    trainer = pl.Trainer(
+        max_epochs=epochs,
+        accelerator="auto",
+        log_every_n_steps=5,
+        logger=wandb_logger,
+        check_val_every_n_epoch=5,
+        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
+    )
+
+    model = LitVisionEncoder(config=model_config, lr=lr)
+    model = overwrite_vision_weights(model, vision_checkpoint)
+
+    hf_tokenizer, _ = get_huggingface_model(
+        model_name=text_model, load_pretrained_if_available=False
+    )
+    dm = RSVQAxBENDataModule(
+        data_dir=resolve_ben_data_dir(data_dir=data_dir, force_mock=True),
+        img_size=(channels, img_size, img_size),
+        num_workers_dataloader=num_workers_dataloader,
+        batch_size=batch_size,
+        max_img_idx=max_img_index,
+    )
+
+    trainer.fit(model=model, datamodule=dm)
+    trainer.test(model=model, datamodule=dm, ckpt_path="best")
+
+    wandb.finish()
+    print("=== Training finished ===")
+
+
+if __name__ == "__main__":
+    typer.run(main)
-- 
GitLab