From 24bf97f5b94c23de532adf3c06487769fb6dde2d Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Thu, 1 Jun 2023 14:01:59 +0200
Subject: [PATCH] preparing for release

---
 README.md      | 141 ++++++++++++++++++
 fusion.py      | 148 -------------------
 results.csv    |  41 ++++++
 train_rsvqa.py | 381 -------------------------------------------------
 4 files changed, 182 insertions(+), 529 deletions(-)
 create mode 100644 README.md
 delete mode 100644 fusion.py
 create mode 100644 results.csv
 delete mode 100644 train_rsvqa.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..6f851d7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,141 @@
+# LiT-4-RSVQA: Lightweight Transformer-based Visual Question Answering in Remote Sensing
+This repository contains the code of the paper [`LiT-4-RSVQA: Lightweight Transformer-based
+Visual Question Answering in Remote Sensing`]() presented at the IEEE International Geoscience and Remote Sensing Symposium (IGARSS) in July 2023. This work has been done at the [Remote Sensing Image Analysis group](https://rsim.berlin/) by [Leonard Hackel](https://rsim.berlin/team/members/leonard-hackel), [Kai Norman Clasen](https://rsim.berlin/team/members/kai-norman-clasen), [Mahdyar Ravanbakhsh](https://rsim.berlin/team/alumni) and [Begüm Demir](https://rsim.berlin/team/members/begum-demir).
+
+If you use this code, please cite our paper given below:
+
+> L. Hackel and K. N. Clasen and M. Ravanbakhsh and B. Demir, "LiT-4-RSVQA: Lightweight Transformer-based Visual Question Answering in Remote Sensing", IEEE International Geoscience and Remote Sensing Symposium, Pasadena, California, 2023.
+
+```bibtex
+@article{Hackel2023,
+  author = "Hackel, Leonard and Clasen, Kai Norman and Ravanbakhsh, Mahdyar and Demir, Begüm",
+  doi = "",
+  journal = "IEEE International Geoscience and Remote Sensing Symposium, Pasadena, California",
+  month = "",
+  title = "{LiT-4-RSVQA}: Lightweight Transformer-based Visual Question Answering in Remote Sensing",
+  url = "",
+  year = "2023"
+}
+```
+
+## Introduction
+Visual question answering (VQA) methods in remote sensing (RS) aim to answer natural language questions with respect to an RS image.
+Most of the existing methods require a large amount of computational resources, which limits their application in operational scenarios in RS.
+To address this issue, in this paper we present an effective lightweight transformer-based VQA in RS (LiT-4-RSVQA) architecture for efficient and accurate VQA in RS.
+Our architecture consists of:
+i) a lightweight text encoder module;
+ii) a lightweight image encoder module;
+iii) a fusion module;
+and iv) a classification module.
+The experimental results obtained on a VQA benchmark dataset demonstrate that our proposed LiT-4-RSVQA architecture provides accurate VQA results while significantly reducing the computational requirements on the executing hardware.
+
+## Dataset
+Two datasets were used for training:
+1. Pre-training: [BigEarthNet](https://bigearth.net/) (only Sentinel-2)
+2. Training: [RSVQAxBEN](https://zenodo.org/record/5084904) extended with all available Sentinel-2 10m and 20m bands included in [BigEarthNet](https://bigearth.net/)
+
+## Prerequisites
+The code in this repository uses the requirements specified in `conda_env.yml`. To install the requirements, call `conda env create -f conda_env.yml`.
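+For example, the environment can be created and activated as follows (the environment name is defined in `conda_env.yml`; the name below is only a placeholder):
+```
+conda env create -f conda_env.yml
+conda activate <environment name from conda_env.yml>
+```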
+
+The code was tested on Ubuntu 20.04 with NVIDIA A100 GPUs (driver version 515.105.01, CUDA version 11.7).
+
+Create an account at [Weights & Biases](https://wandb.ai/) and save the `API KEY` in an environment variable named `WANDB_API_KEY` (e.g. `export WANDB_API_KEY=abcdefg` in the executing shell) so that results can be logged.
+
+## Pre-Training
+To pretrain the image encoder, call
+```
+python pretrain_lit4rsvqa.py --batch-size 512 --epochs 100 --model-name <MODEL_TO_TRAIN> --data-dir <path/to/BigEarthNet/Encoder/BigEarthNetEncoded.lmdb>
+```
+Optionally, the following parameters can be set:
+- `--lr FLOAT` the learning rate
+- `--num-workers-dataloader INT` the number of workers used in the PyTorch dataloader
+- `--seed INT` the seed used to initialize the RNG
+- `--test-run` or `--no-test-run` (default) only use a limited number of batches to check that the individual steps of training work
+
+See also: `python pretrain_lit4rsvqa.py --help`
+
+Training validates every 5th epoch and saves a checkpoint with the best F1 score.
+After training/validation, the best checkpoint is tested.
+Training progress and results are saved locally and on [Weights & Biases](https://wandb.ai/).
+
+For the publication, the seed 42 was used.
+
+## Training and Evaluation
+
+To train the VQA model end-to-end, call
+```
+python train_lit4rsvqa.py --vision-model <VISION_MODEL_TO_TRAIN> --data-dir <path/to/TRAININGFILES>
+```
+or, with a pre-trained vision encoder,
+```
+python train_lit4rsvqa.py --vision-model <VISION_MODEL_TO_TRAIN> --vision-checkpoint <path/to/vision/weights> --data-dir <path/to/TRAININGFILES>
+```
+where the weights are either plain PyTorch weights or a PyTorch Lightning checkpoint that contains the weights.
+The `TRAININGFILES` directory should be set up as follows:
+```
+<TRAININGFILES>
+├── BigEarthNetEncoded.lmdb
+│   ├── data.mdb
+│   └── lock.mdb
+└── VQA_RSVQAxBEN
+    ├── RSVQAxBEN_QA_test.json
+    ├── RSVQAxBEN_QA_test_subset.json
+    ├── RSVQAxBEN_QA_train.json
+    ├── RSVQAxBEN_QA_train_subset.json
+    ├── RSVQAxBEN_QA_val.json
+    └── RSVQAxBEN_QA_val_subset.json
+```
+
+Optionally, the following parameters can be set:
+- `--lr FLOAT` the learning rate
+- `--num-workers-dataloader INT` the number of workers used in the PyTorch dataloader
+- `--seed INT` the seed used to initialize the RNG
+- `--test-run` or `--no-test-run` (default) only use a limited number of batches to check that the individual steps of training work
+- `--text-model STR` the text model
+- `--epochs INT` the number of epochs; this automatically adjusts the LR schedule
+- `--batch-size INT` the batch size
+- `--matmul-precision STR` the precision used in [torch matmul](https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html)
+
+See also: `python train_lit4rsvqa.py --help`
+
+For the publication, the seeds 42, 1337 and 2023 were used.
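+
+As an illustration, a training call with the optional parameters set explicitly could look like the following (the flag values below are placeholders for illustration only, not the exact settings used for the paper):
+```
+python train_lit4rsvqa.py \
+  --vision-model <VISION_MODEL_TO_TRAIN> \
+  --vision-checkpoint <path/to/vision/weights> \
+  --data-dir <path/to/TRAININGFILES> \
+  --text-model <TEXT_MODEL_TO_TRAIN> \
+  --lr 0.0005 \
+  --epochs 10 \
+  --batch-size 512 \
+  --seed 42 \
+  --num-workers-dataloader 4 \
+  --matmul-precision medium
+```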
+
+## Authors
+**Leonard Hackel**
+https://rsim.berlin/team/members/leonard-hackel
+
+**Kai Norman Clasen**
+https://rsim.berlin/team/members/kai-norman-clasen
+
+**Mahdyar Ravanbakhsh**
+https://rsim.berlin/team/alumni
+
+**Begüm Demir**
+https://rsim.berlin/team/members/begum-demir
+
+For questions, requests and concerns, please contact [Leonard Hackel via mail](mailto:l.hackel@tu-berlin.de).
+
+## License
+The code in this repository is licensed under the **MIT License**:
+```
+MIT License
+
+Copyright (c) 2023 Leonard Hackel
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+```
+
diff --git a/fusion.py b/fusion.py
deleted file mode 100644
index f215eeb..0000000
--- a/fusion.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# from https://github.com/Cadene/vqa.pytorch/blob/master/vqa/models/fusion.py
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-
-
-class AbstractFusion(nn.Module):
-
-    def __init__(self, opt={}):
-        super(AbstractFusion, self).__init__()
-        self.opt = opt
-
-    def forward(self, input_v, input_q):
-        raise NotImplementedError
-
-
-class MLBFusion(AbstractFusion):
-
-    def __init__(self, opt):
-        super(MLBFusion, self).__init__(opt)
-        # Modules
-        if 'dim_v' in self.opt:
-            self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_h'])
-        else:
-            print('Warning fusion.py: no visual embedding before fusion')
-
-        if 'dim_q' in self.opt:
-            self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_h'])
-        else:
-            print('Warning fusion.py: no question embedding before fusion')
-
-    def forward(self, input_v, input_q):
-        # visual (cnn features)
-        if 'dim_v' in self.opt:
-            x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
-            x_v = self.linear_v(x_v)
-            if 'activation_v' in self.opt:
-                x_v = getattr(F, self.opt['activation_v'])(x_v)
-        else:
-            x_v = input_v
-        # question (rnn features)
-        if 'dim_q' in self.opt:
-            x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
-            x_q = self.linear_q(x_q)
-            if 'activation_q' in self.opt:
-                x_q = getattr(F, self.opt['activation_q'])(x_q)
-        else:
-            x_q = input_q
-        # hadamard product
-        x_mm = torch.mul(x_q, x_v)
-        return x_mm
-
-
-class MutanFusion(AbstractFusion):
-
-    def __init__(self, opt, visual_embedding=True, question_embedding=True):
-        super(MutanFusion, self).__init__(opt)
-        self.visual_embedding = visual_embedding
-        self.question_embedding = question_embedding
-        # Modules
-        if self.visual_embedding:
-            self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_hv'])
-        else:
-            print('Warning fusion.py: no visual embedding before fusion')
-
-        if self.question_embedding:
-            self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_hq'])
-        else:
-            print('Warning fusion.py: no question embedding before fusion')
-
-        self.list_linear_hv = nn.ModuleList([
-            nn.Linear(self.opt['dim_hv'], self.opt['dim_mm'])
-            for i in range(self.opt['R'])])
-
-        self.list_linear_hq = nn.ModuleList([
-            nn.Linear(self.opt['dim_hq'], self.opt['dim_mm'])
-            for i in range(self.opt['R'])])
-
-    def forward(self, input_v, input_q):
-        if input_v.dim() != input_q.dim() and input_v.dim() != 2:
-            raise ValueError
-        batch_size = input_v.size(0)
-
-        if self.visual_embedding:
-            x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
-            x_v = self.linear_v(x_v)
-            if 'activation_v' in self.opt:
-                x_v = getattr(F, self.opt['activation_v'])(x_v)
-        else:
-            x_v = input_v
-
-        if self.question_embedding:
-            x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
-            x_q = self.linear_q(x_q)
-            if 'activation_q' in self.opt:
-                x_q = getattr(F, self.opt['activation_q'])(x_q)
-        else:
-            x_q = input_q
-
-        x_mm = []
-        for i in range(self.opt['R']):
-
-            x_hv = F.dropout(x_v, p=self.opt['dropout_hv'], training=self.training)
-            x_hv = self.list_linear_hv[i](x_hv)
-            if 'activation_hv' in self.opt:
-                x_hv = getattr(F, self.opt['activation_hv'])(x_hv)
-
-            x_hq = F.dropout(x_q, p=self.opt['dropout_hq'], training=self.training)
-            x_hq = self.list_linear_hq[i](x_hq)
-            if 'activation_hq' in self.opt:
-                x_hq = getattr(F, self.opt['activation_hq'])(x_hq)
-
-            x_mm.append(torch.mul(x_hq, x_hv))
-
-        x_mm = torch.stack(x_mm, dim=1)
-        x_mm = x_mm.sum(1).view(batch_size, self.opt['dim_mm'])
-
-        if 'activation_mm' in self.opt:
-            x_mm = getattr(F, self.opt['activation_mm'])(x_mm)
-
-        return x_mm
-
-
-class MutanFusion2d(MutanFusion):
-
-    def __init__(self, opt, visual_embedding=True, question_embedding=True):
-        super(MutanFusion2d, self).__init__(opt,
-                                            visual_embedding,
-                                            question_embedding)
-
-    def forward(self, input_v, input_q):
-        if input_v.dim() != input_q.dim() and input_v.dim() != 3:
-            raise ValueError
-        batch_size = input_v.size(0)
-        weight_height = input_v.size(1)
-        dim_hv = input_v.size(2)
-        dim_hq = input_q.size(2)
-        if not input_v.is_contiguous():
-            input_v = input_v.contiguous()
-        if not input_q.is_contiguous():
-            input_q = input_q.contiguous()
-        x_v = input_v.view(batch_size * weight_height, self.opt['dim_hv'])
-        x_q = input_q.view(batch_size * weight_height, self.opt['dim_hq'])
-        x_mm = super().forward(x_v, x_q)
-        x_mm = x_mm.view(batch_size, weight_height, self.opt['dim_mm'])
-        return x_mm
diff --git a/results.csv b/results.csv
new file mode 100644
index 0000000..64f2cbf
--- /dev/null
+++ b/results.csv
@@ -0,0 +1,41 @@
+,Vision Encoder,pretrained,Text Encoder,Fusion,LULC,YN,AA,OA,Params,Flops,seed
+Mean,,,,,36.26,87.83,62.04,79.15,92.6,4.4,
+Variance,,,,,1.1,0.3,0.7,0.43,0,0,
+,Deit3 Base,yes,Bert Tiny,Simple,0.3743,0.8809,0.6276,0.7957,92.6,4.4,42
+,Deit3 Base,yes,Bert Tiny,Simple,0.3611,0.8789,0.62,0.7918,92.6,4.4,1337
+,Deit3 Base,yes,Bert Tiny,Simple,0.3524,0.875,0.6137,0.7871,92.6,4.4,2023
+Mean,,,,,29.53,85.21,57.36,75.84,92.6,4.4,
+Variance,,,,,3.1,1.39,2.24,1.67,0,0,
+,Deit3 Base,no,Bert Tiny,Simple,0.3265,0.8652,0.5958,0.7745,92.6,4.4,42
+,Deit3 Base,no,Bert Tiny,Simple,0.2948,0.8535,0.5741,0.7595,92.6,4.4,1337
+,Deit3 Base,no,Bert Tiny,Simple,0.2645,0.8375,0.551,0.7411,92.6,4.4,2023
+Mean,,,,,38.37,88.65,63.51,80.19,11,0.3,
+Variance,,,,,1.17,0.39,0.64,0.39,0,0,
+,Deit Tiny,yes,Bert Tiny,Simple,0.3784,0.8904,0.6344,0.8042,11,0.3,42
+,Deit Tiny,yes,Bert Tiny,Simple,0.3755,0.8827,0.6291,0.7974,11,0.3,1337
+,Deit Tiny,yes,Bert Tiny,Simple,0.3971,0.8865,0.6418,0.8042,11,0.3,2023
+Mean,,,,,32.53,86.8,59.67,77.67,11,0.3,
+Variance,,,,,1.93,0.43,1.16,0.66,0,0,
+,Deit Tiny,no,Bert Tiny,Simple,0.3063,0.8631,0.5847,0.7694,11,0.3,42
+,Deit Tiny,no,Bert Tiny,Simple,0.3247,0.8704,0.5975,0.7786,11,0.3,1337
+,Deit Tiny,no,Bert Tiny,Simple,0.3449,0.8706,0.6078,0.7822,11,0.3,2023
+Mean,,,,,34.43,89.02,61.73,79.95,10.4,0.5,
+Variance,,,,,1.94,0.27,0.96,0.23,0,0,
+,MobileViT S,yes,Bert Tiny,Simple,0.3545,0.8928,0.6236,0.8022,10.4,0.5,42
+,MobileViT S,yes,Bert Tiny,Simple,0.322,0.8904,0.6062,0.7982,10.4,0.5,1337
+,MobileViT S,yes,Bert Tiny,Simple,0.3565,0.8875,0.622,0.7982,10.4,0.5,2023
+Mean,,,,,27.18,85.51,56.34,75.7,10.4,0.5,
+Variance,,,,,3.18,2.19,2.63,2.32,0,0,
+,MobileViT S,no,Bert Tiny,Simple,0.3067,0.8802,0.5935,0.7838,10.4,0.5,42
+,MobileViT S,no,Bert Tiny,Simple,0.2641,0.8399,0.552,0.743,10.4,0.5,1337
+,MobileViT S,no,Bert Tiny,Simple,0.2446,0.8451,0.5448,0.7441,10.4,0.5,2023
+Mean,,,,,39.5,89.72,64.61,81.27,8.1,0.6,
+Variance,,,,,2.04,0.65,1.35,0.88,0,0,
+,XciT Nano,yes,Bert Tiny,Simple,0.4185,0.9047,0.6616,0.8229,8.1,0.6,42
+,XciT Nano,yes,Bert Tiny,Simple,0.3821,0.8942,0.6381,0.8081,8.1,0.6,1337
+,XciT Nano,yes,Bert Tiny,Simple,0.3843,0.8927,0.6385,0.8072,8.1,0.6,2023
+Mean,,,,,21.45,81.97,51.71,71.79,8.1,0.6,
+Variance,,,,,4.07,4.66,4.36,4.56,0,0,
+,XciT Nano,no,Bert Tiny,Simple,0.1676,0.7662,0.4669,0.6655,8.1,0.6,42
+,XciT Nano,no,Bert Tiny,Simple,0.2403,0.8517,0.546,0.7489,8.1,0.6,1337
+,XciT Nano,no,Bert Tiny,Simple,0.2355,0.8412,0.5383,0.7393,8.1,0.6,2023
diff --git a/train_rsvqa.py b/train_rsvqa.py
deleted file mode 100644
index 85b34b5..0000000
--- a/train_rsvqa.py
+++ /dev/null
@@ -1,381 +0,0 @@
-# import packages
-import pytorch_lightning as pl
-import torch
-import torch.nn.functional as F
-from torch import optim
-from tqdm import tqdm
-
-from configilm import ConfigILM
-from configilm.ConfigILM import ILMConfiguration, ILMType
-from configilm.ConfigILM import get_hf_model as get_huggingface_model
-from configilm.extra.RSVQAxBEN_DataModule_LMDB_Encoder import RSVQAxBENDataModule
-from configilm.extra.BEN_lmdb_utils import resolve_ben_data_dir
-import typer
-import os
-from os.path import isfile
-import wandb
-from pytorch_lightning.loggers.wandb import WandbLogger
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.callbacks import EarlyStopping
-from pytorch_lightning.callbacks import LearningRateMonitor
-from sklearn.metrics import accuracy_score
-from torchmetrics.classification import MultilabelF1Score
-from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
-from fvcore.nn import parameter_count, FlopCountAnalysis
-
-
-from fusion import MutanFusion
-
-__author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
-os.environ["WANDB_START_METHOD"] = "thread"
-wandb_api_key = os.environ["WANDB_API_KEY"]
-
-
-class LitVisionEncoder(pl.LightningModule):
-    """
-    Wrapper around a pytorch module, allowing this module to be used in automatic
-    training with pytorch lightning.
-    Among other things, the wrapper allows us to do automatic training and removes the
-    need to manage data on different devices (e.g. GPU and CPU).
-    """
-
-    def __init__(
-        self,
-        config: ConfigILM.ILMConfiguration,
-        lr: float = 1e-3,
-    ):
-        super().__init__()
-        self.lr = lr
-        self.config = config
-        self.model = ConfigILM.ConfigILM(config)
-
-    def _disassemble_batch(self, batch):
-        images, questions, labels = batch
-        # transposing tensor, needed for Huggingface-Dataloader combination
-        questions = torch.tensor(
-            [x.tolist() for x in questions], device=self.device
-        ).T.int()
-        return (images, questions), labels
-
-    def training_step(self, batch, batch_idx):
-        x, y = self._disassemble_batch(batch)
-        x_hat = self.model(x)
-        loss = F.binary_cross_entropy_with_logits(x_hat, y)
-        self.log("train/loss", loss)
-        return {"loss": loss}
-
-    def configure_optimizers(self):
-        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
-
-        # these are steps if interval is set to step
-        max_intervals = int(self.trainer.max_epochs *
-                            len(self.trainer.datamodule.train_ds) /
-                            self.trainer.datamodule.batch_size)
-        warmup = 10000 if max_intervals > 10000 else 100 if max_intervals > 100 else 0
-
-        print(f"Optimizing for {max_intervals} steps with warmup for {warmup} steps")
-
-        lr_scheduler = {
-            'scheduler': LinearWarmupCosineAnnealingLR(
-                optimizer,
-                warmup_epochs=warmup,
-                max_epochs=max_intervals,
-                warmup_start_lr=self.lr / 10,
-                eta_min=self.lr / 10
-            ),
-            'name': 'learning_rate',
-            'interval': "step",
-            'frequency': 1
-        }
-        return [optimizer], [lr_scheduler]
-
-    def validation_step(self, batch, batch_idx):
-        x, y = self._disassemble_batch(batch)
-        x_hat = self.model(x)
-        loss = F.binary_cross_entropy_with_logits(x_hat, y)
-        return {"loss": loss, "outputs": x_hat, "labels": y}
-
-    def validation_epoch_end(self, outputs):
-        metrics = self.get_metrics(outputs)
-
-        self.log("val/loss", metrics["avg_loss"])
-        self.log("val/f1", metrics["avg_f1_score"])
-        self.log("val/Accuracy (LULC)", metrics["accuracy"]["LULC"])
-        self.log("val/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
-        self.log("val/Accuracy (Overall)", metrics["accuracy"]["Overall"])
-        self.log("val/Accuracy (Average)", metrics["accuracy"]["Average"])
-
-    def test_step(self, batch, batch_idx):
-        x, y = self._disassemble_batch(batch)
-        x_hat = self.model(x)
-        loss = F.binary_cross_entropy_with_logits(x_hat, y)
-        return {"loss": loss, "outputs": x_hat, "labels": y}
-
-    def test_epoch_end(self, outputs):
-        metrics = self.get_metrics(outputs)
-
-        self.log("test/loss", metrics["avg_loss"])
-        self.log("test/f1", metrics["avg_f1_score"])
-        self.log("test/Accuracy (LULC)", metrics["accuracy"]["LULC"])
-        self.log("test/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
-        self.log("test/Accuracy (Overall)", metrics["accuracy"]["Overall"])
-        self.log("test/Accuracy (Average)", metrics["accuracy"]["Average"])
-
-    def forward(self, batch):
-        # because we are a wrapper, we call the inner function manually
-        return self.model(batch)
-
-    def get_metrics(self, outputs):
-        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
-        logits = torch.cat([x["outputs"].cpu() for x in outputs], 0)
-        labels = torch.cat(
-            [x["labels"].cpu() for x in outputs], 0
-        )  # Tensor of size (#samples x classes)
-
-        selected_answers = self.trainer.datamodule.selected_answers
-
-        argmax_out = torch.argmax(logits, dim=1)
-        argmax_lbl = torch.argmax(labels, dim=1)
-
-        # get answers and predictions per type
-        yn_preds = []
-        yn_gts = []
-        lulc_preds = []
-        lulc_gts = []
-
-        for i, ans in enumerate(tqdm(argmax_lbl, desc="Counting answers")):
-            # Yes/No question
-            if selected_answers[ans] in ["yes", "no"]:
"no"]: - - # stored for global Yes/No - yn_preds.append(argmax_out[i]) - yn_gts.append(ans) - - # LC question - else: - # stored for global LC - lulc_preds.append(argmax_out[i]) - lulc_gts.append(ans) - - acc_yn = accuracy_score(yn_gts, yn_preds) - acc_lulc = accuracy_score(lulc_gts, lulc_preds) - - accuracy_dict = { - "Yes/No": acc_yn, - "LULC": acc_lulc, - "Overall": accuracy_score( - argmax_lbl, argmax_out - ), # micro average on classes - "Average": (acc_yn + acc_lulc) / 2, # macro average on types - } - - f1_score = MultilabelF1Score(num_labels=self.config.classes, average=None).to( - logits.device - )(logits, labels) - - avg_f1_score = float( - torch.sum(f1_score) / self.config.classes - ) # macro average f1 score - - return { - "avg_loss": avg_loss, - "avg_f1_score": avg_f1_score, - "accuracy": accuracy_dict, - } - - -def overwrite_vision_weights(model, vision_checkpoint): - if vision_checkpoint is None: - return model - if not isfile(vision_checkpoint): - print("Pretrained vision model not available, cannot load checkpoint") - return model - # load weights - # get model and pretrained state dicts - if torch.cuda.is_available(): - pretrained_dict = torch.load(vision_checkpoint) - else: - pretrained_dict = torch.load( - vision_checkpoint, map_location=torch.device("cpu") - ) - model_dict = model.state_dict() - - # filter out unnecessary keys - # this allows to load lightning or pytorch model loading - if "pytorch-lightning_version" in pretrained_dict.keys(): - # checkpoint is a Pytorch-Lightning Checkpoint - pretrained_dict = { - k: v - for k, v in pretrained_dict["state_dict"].items() - if k in model_dict - } - else: - pretrained_dict = { - k: v for k, v in pretrained_dict.items() if k in model_dict - } - - # filter keys that have a size mismatch - mismatch_keys = [ - x - for x in pretrained_dict.keys() - if pretrained_dict[x].shape != model_dict[x].shape - ] - for key in mismatch_keys: - del pretrained_dict[key] - print(f"Key '{key}' size mismatch, removing from loading") - - # overwrite entries in the existing state dict - model_dict.update(pretrained_dict) - - # load the new state dict - model.load_state_dict(model_dict) - print("Vision Model checkpoint loaded") - return model - - -def mutan(fusion_in: int, fusion_out: int): - opt = { # values copied from chappuis' code - 'dim_v': 1200, # 2048, - 'dim_q': 1200, # 2400, - 'dim_hv': 360, - 'dim_hq': 360, - 'dim_mm': 360, - 'R': 10, - 'dropout_v': 0.5, - 'dropout_q': 0.5, - 'activation_v': 'tanh', - 'activation_q': 'tanh', - 'dropout_hv': 0, - 'dropout_hq': 0 - } - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - f = MutanFusion(opt=opt).to(device) - return f - - -def main( - vision_model: str = "resnet152", - text_model: str = "bert-base-uncased", - lr: float = 5e-4, - epochs: int = 10, - batch_size: int = 512, - seed: int = 42, - data_dir: str = None, - test_run: bool = False, - num_workers_dataloader: int = 4, - vision_checkpoint: str = None -): - if test_run: - max_img_index = 10 * batch_size - epochs = 10 - else: - max_img_index = -1 - - pl.seed_everything(seed, workers=True) - - torch.set_float32_matmul_precision("medium") - - img_size = 120 - channels = 10 - - fusion_in = 1200 # Chappuis FUSION_IN - # !! 
-    fusion_out = 360  # Chappuis MUTAN_OUT
-
-
-    model_config = ILMConfiguration(
-        timm_model_name=vision_model,
-        hf_model_name=text_model,
-        classes=1000,
-        image_size=img_size,
-        channels=channels,
-        network_type=ILMType.VQA_CLASSIFICATION,
-        visual_features_out=2048,  # Chappuis VISUAL_OUT
-        fusion_in=fusion_in,
-        fusion_out=fusion_out,
-        fusion_method=mutan(fusion_in=fusion_in, fusion_out=fusion_out),
-        fusion_hidden=256,  # Chappuis FUSION_HDIDDEN,
-        v_dropout_rate=0.5,  # Chappuis DROPOUT_V
-        t_dropout_rate=0.5,  # Chappuis DROPOUT_Q
-        fusion_dropout_rate=0.5  # Chappuis DROPOUT_F
-    )
-
-    # Key is available by wandb, project name can be chosen at will
-    wandb.login(key=wandb_api_key)
-
-    tags = ["Training", vision_model, text_model]
-    if test_run:
-        tags += ["Test Run"]
-    if vision_checkpoint is not None:
-        tags += ["Vision Pretraining"]
-        tags += ["Chappuis HParams"]
-    wandb_logger = WandbLogger(project=f"LiT4RSVQA",
-                               log_model=True,
-                               tags=tags,  # keyword arg directly to wandb.init()
-                               )
-
-    monitor = "val/Accuracy (Average)"
-    monitor_str = "AA"
-    # checkpointing
-    checkpoint_callback = ModelCheckpoint(
-        monitor="val/f1",
-        dirpath="./checkpoints",
-        filename=f"{wandb_logger.experiment.name}-seed=" +
-                 str(seed) + "-epoch={epoch:03d}-" + f"{monitor_str}" + "={" +
-                 f"{monitor}" + ":.3f}",
-        auto_insert_metric_name=False,
-        save_top_k=1,
-        mode="max",
-        save_last=True
-    )
-    early_stopping_callback = EarlyStopping(monitor=monitor, min_delta=0.00,
-                                            patience=25, verbose=False, mode="max")
-    lr_monitor = LearningRateMonitor(logging_interval='step')
-
-    trainer = pl.Trainer(
-        max_epochs=epochs,
-        accelerator="auto",
-        log_every_n_steps=5,
-        logger=wandb_logger,
-        check_val_every_n_epoch=2,
-        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
-    )
-
-    model = LitVisionEncoder(config=model_config, lr=lr)
-    model = overwrite_vision_weights(model, vision_checkpoint)
-
-    hf_tokenizer, _ = get_huggingface_model(
-        model_name=text_model, load_pretrained_if_available=False
-    )
-    dm = RSVQAxBENDataModule(
-        data_dir=resolve_ben_data_dir(data_dir=data_dir),
-        img_size=(channels, img_size, img_size),
-        num_workers_dataloader=num_workers_dataloader,
-        batch_size=batch_size,
-        max_img_idx=max_img_index,
-        tokenizer=hf_tokenizer
-    )
-
-    wandb_logger.log_hyperparams(
-        {
-            "Vision Model": vision_model,
-            "Text Model": text_model,
-            "Learning Rate": lr,
-            "Epochs": epochs,
-            "Batch Size": batch_size,
-            "Seed": seed,
-            "# Workers": num_workers_dataloader,
-            "Vision Checkpoint": vision_checkpoint,
-            "GPU": torch.cuda.get_device_name()
-        }
-    )
-
-    trainer.fit(model=model, datamodule=dm)
-    trainer.test(model=model, datamodule=dm, ckpt_path="best")
-
-    wandb.finish()
-    print("=== Training finished ===")
-
-
-if __name__ == "__main__":
-    typer.run(main)
--
GitLab