From d43df488d81aa6d787b8167733565989c615b2be Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Wed, 3 May 2023 08:18:45 +0200
Subject: [PATCH] Adding TransformerRSVQA code and update to configilm v0.3.0
 to make code runnable

---
 conda_env.yml  |   2 +-
 fusion.py      | 148 +++++++++++++++++++++
 train_rsvqa.py | 350 +++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 499 insertions(+), 1 deletion(-)
 create mode 100644 fusion.py
 create mode 100644 train_rsvqa.py

diff --git a/conda_env.yml b/conda_env.yml
index 49e1f6a..b92f7df 100644
--- a/conda_env.yml
+++ b/conda_env.yml
@@ -57,7 +57,7 @@ dependencies:
   - pip:
     - attrs==23.1.0
     - colorama==0.4.6
-    - configilm==0.2.0
+    - configilm==0.3.0
     - cycler==0.11.0
     - fonttools==4.39.3
     - kiwisolver==1.4.4
diff --git a/fusion.py b/fusion.py
new file mode 100644
index 0000000..f215eeb
--- /dev/null
+++ b/fusion.py
@@ -0,0 +1,148 @@
+# from https://github.com/Cadene/vqa.pytorch/blob/master/vqa/models/fusion.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+
+class AbstractFusion(nn.Module):
+
+    def __init__(self, opt={}):
+        super(AbstractFusion, self).__init__()
+        self.opt = opt
+
+    def forward(self, input_v, input_q):
+        raise NotImplementedError
+
+
+class MLBFusion(AbstractFusion):
+
+    def __init__(self, opt):
+        super(MLBFusion, self).__init__(opt)
+        # Modules
+        if 'dim_v' in self.opt:
+            self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_h'])
+        else:
+            print('Warning fusion.py: no visual embedding before fusion')
+
+        if 'dim_q' in self.opt:
+            self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_h'])
+        else:
+            print('Warning fusion.py: no question embedding before fusion')
+
+    def forward(self, input_v, input_q):
+        # visual (cnn features)
+        if 'dim_v' in self.opt:
+            x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
+            x_v = self.linear_v(x_v)
+            if 'activation_v' in self.opt:
+                x_v = getattr(F, self.opt['activation_v'])(x_v)
+        else:
+            x_v = input_v
+        # question (rnn features)
+        if 'dim_q' in self.opt:
+            x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
+            x_q = self.linear_q(x_q)
+            if 'activation_q' in self.opt:
+                x_q = getattr(F, self.opt['activation_q'])(x_q)
+        else:
+            x_q = input_q
+        # hadamard product
+        x_mm = torch.mul(x_q, x_v)
+        return x_mm
+
+
+class MutanFusion(AbstractFusion):
+
+    def __init__(self, opt, visual_embedding=True, question_embedding=True):
+        super(MutanFusion, self).__init__(opt)
+        self.visual_embedding = visual_embedding
+        self.question_embedding = question_embedding
+        # Modules
+        if self.visual_embedding:
+            self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_hv'])
+        else:
+            print('Warning fusion.py: no visual embedding before fusion')
+
+        if self.question_embedding:
+            self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_hq'])
+        else:
+            print('Warning fusion.py: no question embedding before fusion')
+
+        self.list_linear_hv = nn.ModuleList([
+            nn.Linear(self.opt['dim_hv'], self.opt['dim_mm'])
+            for i in range(self.opt['R'])])
+
+        self.list_linear_hq = nn.ModuleList([
+            nn.Linear(self.opt['dim_hq'], self.opt['dim_mm'])
+            for i in range(self.opt['R'])])
+
+    def forward(self, input_v, input_q):
+        if input_v.dim() != input_q.dim() and input_v.dim() != 2:
+            raise ValueError
+        batch_size = input_v.size(0)
+
+        if self.visual_embedding:
+            x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
+            x_v = self.linear_v(x_v)
+            if 'activation_v' in self.opt:
+                x_v = getattr(F, self.opt['activation_v'])(x_v)
+        else:
+            x_v = input_v
+
+        if self.question_embedding:
+            x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
+            x_q = self.linear_q(x_q)
+            if 'activation_q' in self.opt:
+                x_q = getattr(F, self.opt['activation_q'])(x_q)
+        else:
+            x_q = input_q
+
+        x_mm = []
+        for i in range(self.opt['R']):
+
+            x_hv = F.dropout(x_v, p=self.opt['dropout_hv'], training=self.training)
+            x_hv = self.list_linear_hv[i](x_hv)
+            if 'activation_hv' in self.opt:
+                x_hv = getattr(F, self.opt['activation_hv'])(x_hv)
+
+            x_hq = F.dropout(x_q, p=self.opt['dropout_hq'], training=self.training)
+            x_hq = self.list_linear_hq[i](x_hq)
+            if 'activation_hq' in self.opt:
+                x_hq = getattr(F, self.opt['activation_hq'])(x_hq)
+
+            x_mm.append(torch.mul(x_hq, x_hv))
+
+        x_mm = torch.stack(x_mm, dim=1)
+        x_mm = x_mm.sum(1).view(batch_size, self.opt['dim_mm'])
+
+        if 'activation_mm' in self.opt:
+            x_mm = getattr(F, self.opt['activation_mm'])(x_mm)
+
+        return x_mm
+
+
+class MutanFusion2d(MutanFusion):
+
+    def __init__(self, opt, visual_embedding=True, question_embedding=True):
+        super(MutanFusion2d, self).__init__(opt,
+                                            visual_embedding,
+                                            question_embedding)
+
+    def forward(self, input_v, input_q):
+        if input_v.dim() != input_q.dim() and input_v.dim() != 3:
+            raise ValueError
+        batch_size = input_v.size(0)
+        weight_height = input_v.size(1)
+        dim_hv = input_v.size(2)
+        dim_hq = input_q.size(2)
+        if not input_v.is_contiguous():
+            input_v = input_v.contiguous()
+        if not input_q.is_contiguous():
+            input_q = input_q.contiguous()
+        x_v = input_v.view(batch_size * weight_height, self.opt['dim_hv'])
+        x_q = input_q.view(batch_size * weight_height, self.opt['dim_hq'])
+        x_mm = super().forward(x_v, x_q)
+        x_mm = x_mm.view(batch_size, weight_height, self.opt['dim_mm'])
+        return x_mm
diff --git a/train_rsvqa.py b/train_rsvqa.py
new file mode 100644
index 0000000..48e92c1
--- /dev/null
+++ b/train_rsvqa.py
@@ -0,0 +1,350 @@
+# import packages
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+from torch import optim
+from tqdm import tqdm
+
+from configilm import ConfigILM
+from configilm.ConfigILM import ILMConfiguration, ILMType
+from configilm.ConfigILM import get_hf_model as get_huggingface_model
+from configilm.extra.RSVQAxBEN_DataModule_LMDB_Encoder import RSVQAxBENDataModule
+from configilm.extra.BEN_lmdb_utils import resolve_ben_data_dir
+import typer
+import os
+from os.path import isfile
+import wandb
+from pytorch_lightning.loggers.wandb import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import LearningRateMonitor
+from sklearn.metrics import accuracy_score
+from torchmetrics.classification import MultilabelF1Score
+from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
+
+from fusion import MutanFusion
+
+__author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
+os.environ["WANDB_START_METHOD"] = "thread"
+wandb_api_key = os.environ["WANDB_API_KEY"]
+
+
+class LitVisionEncoder(pl.LightningModule):
+    """
+    Wrapper around a pytorch module, allowing this module to be used in automatic
+    training with pytorch lightning.
+    Among other things, the wrapper removes the need to manage data on different
+    devices (e.g. GPU and CPU).
+    """
+
+    def __init__(
+        self,
+        config: ConfigILM.ILMConfiguration,
+        lr: float = 1e-3,
+    ):
+        super().__init__()
+        self.lr = lr
+        self.config = config
+        self.model = ConfigILM.ConfigILM(config)
+
+    def _disassemble_batch(self, batch):
+        images, questions, labels = batch
+        # transposing tensor, needed for Huggingface-Dataloader combination
+        questions = torch.tensor(
+            [x.tolist() for x in questions], device=self.device
+        ).T.int()
+        return (images, questions), labels
+
+    def training_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        self.log("train/loss", loss)
+        return {"loss": loss}
+
+    def configure_optimizers(self):
+        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
+
+        # these are steps if interval is set to step
+        max_intervals = int(self.trainer.max_epochs *
+                            len(self.trainer.datamodule.train_ds) /
+                            self.trainer.datamodule.batch_size)
+        warmup = 10000 if max_intervals > 10000 else 100 if max_intervals > 100 else 0
+
+        print(f"Optimizing for {max_intervals} steps with warmup for {warmup} steps")
+
+        lr_scheduler = {
+            'scheduler': LinearWarmupCosineAnnealingLR(
+                optimizer,
+                warmup_epochs=warmup,
+                max_epochs=max_intervals,
+                warmup_start_lr=self.lr / 10,
+                eta_min=self.lr / 10
+            ),
+            'name': 'learning_rate',
+            'interval': "step",
+            'frequency': 1
+        }
+        return [optimizer], [lr_scheduler]
+
+    def validation_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        return {"loss": loss, "outputs": x_hat, "labels": y}
+
+    def validation_epoch_end(self, outputs):
+        metrics = self.get_metrics(outputs)
+
+        self.log("val/loss", metrics["avg_loss"])
+        self.log("val/f1", metrics["avg_f1_score"])
+        self.log("val/Accuracy (LULC)", metrics["accuracy"]["LULC"])
+        self.log("val/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
+        self.log("val/Accuracy (Overall)", metrics["accuracy"]["Overall"])
+        self.log("val/Accuracy (Average)", metrics["accuracy"]["Average"])
+
+    def test_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        return {"loss": loss, "outputs": x_hat, "labels": y}
+
+    def test_epoch_end(self, outputs):
+        metrics = self.get_metrics(outputs)
+
+        self.log("test/loss", metrics["avg_loss"])
+        self.log("test/f1", metrics["avg_f1_score"])
+        self.log("test/Accuracy (LULC)", metrics["accuracy"]["LULC"])
+        self.log("test/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
+        self.log("test/Accuracy (Overall)", metrics["accuracy"]["Overall"])
+        self.log("test/Accuracy (Average)", metrics["accuracy"]["Average"])
+
+    def forward(self, batch):
+        # because we are a wrapper, we call the inner function manually
+        return self.model(batch)
+
+    def get_metrics(self, outputs):
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+        logits = torch.cat([x["outputs"].cpu() for x in outputs], 0)
+        labels = torch.cat(
+            [x["labels"].cpu() for x in outputs], 0
+        )  # Tensor of size (#samples x classes)
+
+        selected_answers = self.trainer.datamodule.selected_answers
+
+        argmax_out = torch.argmax(logits, dim=1)
+        argmax_lbl = torch.argmax(labels, dim=1)
+
+        # get answers and predictions per type
+        yn_preds = []
+        yn_gts = []
+        lulc_preds = []
+        lulc_gts = []
+
+        for i, ans in enumerate(tqdm(argmax_lbl, desc="Counting answers")):
+            # Yes/No question
+            if selected_answers[ans] in ["yes", "no"]:

+                # stored for global Yes/No
+                yn_preds.append(argmax_out[i])
+                yn_gts.append(ans)
+
+            # LC question
+            else:
+                # stored for global LC
+                lulc_preds.append(argmax_out[i])
+                lulc_gts.append(ans)
+
+        acc_yn = accuracy_score(yn_gts, yn_preds)
+        acc_lulc = accuracy_score(lulc_gts, lulc_preds)
+
+        accuracy_dict = {
+            "Yes/No": acc_yn,
+            "LULC": acc_lulc,
+            "Overall": accuracy_score(
+                argmax_lbl, argmax_out
+            ),  # micro average on classes
+            "Average": (acc_yn + acc_lulc) / 2,  # macro average on types
+        }
+
+        f1_score = MultilabelF1Score(num_labels=self.config.classes, average=None).to(
+            logits.device
+        )(logits, labels)
+
+        avg_f1_score = float(
+            torch.sum(f1_score) / self.config.classes
+        )  # macro average f1 score
+
+        return {
+            "avg_loss": avg_loss,
+            "avg_f1_score": avg_f1_score,
+            "accuracy": accuracy_dict,
+        }
+
+
+def overwrite_vision_weights(model, vision_checkpoint):
+    if vision_checkpoint is None:
+        return model
+    if not isfile(vision_checkpoint):
+        print("Pretrained vision model not available, cannot load checkpoint")
+        return model
+    # load weights
+    # get model and pretrained state dicts
+    if torch.cuda.is_available():
+        pretrained_dict = torch.load(vision_checkpoint)
+    else:
+        pretrained_dict = torch.load(
+            vision_checkpoint, map_location=torch.device("cpu")
+        )
+    model_dict = model.state_dict()
+
+    # filter out unnecessary keys
+    # this allows loading either Lightning or plain PyTorch checkpoints
+    if "pytorch-lightning_version" in pretrained_dict.keys():
+        # checkpoint is a Pytorch-Lightning Checkpoint
+        pretrained_dict = {
+            k: v
+            for k, v in pretrained_dict["state_dict"].items()
+            if k in model_dict
+        }
+    else:
+        pretrained_dict = {
+            k: v for k, v in pretrained_dict.items() if k in model_dict
+        }
+
+    # filter keys that have a size mismatch
+    mismatch_keys = [
+        x
+        for x in pretrained_dict.keys()
+        if pretrained_dict[x].shape != model_dict[x].shape
+    ]
+    for key in mismatch_keys:
+        del pretrained_dict[key]
+        print(f"Key '{key}' size mismatch, removing from loading")
+
+    # overwrite entries in the existing state dict
+    model_dict.update(pretrained_dict)
+
+    # load the new state dict
+    model.load_state_dict(model_dict)
+    print("Vision Model checkpoint loaded")
+    return model
+
+
+def mutan(fusion_in: int, fusion_out: int):
+    opt = {
+        'dim_hv': fusion_in,
+        'dim_hq': fusion_in,
+        'dim_mm': fusion_out,
+        'dropout_hv': 0.1,
+        'dropout_hq': 0.1,
+        'R': 1
+    }
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    f = MutanFusion(visual_embedding=False,
+                    question_embedding=False,
+                    opt=opt).to(device)
+    return f
+
+
+def main(
+    vision_model: str = "resnet152",
+    text_model: str = "bert-base-uncased",
+    lr: float = 1e-3,
+    epochs: int = 100,
+    batch_size: int = 32,
+    seed: int = 42,
+    data_dir: str = None,
+    test_run: bool = False,
+    num_workers_dataloader: int = 4,
+    vision_checkpoint: str = None
+):
+    if test_run:
+        max_img_index = 10 * batch_size
+        epochs = 10
+    else:
+        max_img_index = -1
+
+    pl.seed_everything(seed, workers=True)
+
+    img_size = 120
+    channels = 10
+
+    fusion_in = 1200
+    fusion_out = 360
+    model_config = ILMConfiguration(
+        timm_model_name=vision_model,
+        hf_model_name=text_model,
+        classes=1000,
+        image_size=img_size,
+        channels=channels,
+        network_type=ILMType.VQA_CLASSIFICATION,
+        visual_features_out=2048,
+        fusion_in=fusion_in,
+        fusion_out=fusion_out,
+        fusion_method=mutan(fusion_in=fusion_in, fusion_out=fusion_out),
+        fusion_hidden=256
+    )
+
+    # Key is available from wandb; the project name can be chosen at will
+    wandb.login(key=wandb_api_key)
+
+    tags = ["Training", vision_model, text_model]
+    if test_run:
+        tags += ["Test Run"]
+    wandb_logger = WandbLogger(project=f"LiT4RSVQA",
+                               log_model=True,
+                               tags=tags,  # keyword arg directly to wandb.init()
+                               )
+
+    monitor = "val/f1"
+    monitor_str = "F1_score"
+    # checkpointing
+    checkpoint_callback = ModelCheckpoint(
+        monitor="val/f1",
+        dirpath="./checkpoints",
+        filename=f"{wandb_logger.experiment.name}-{vision_model}-{text_model}-seed="
+                 + str(seed) + "-epoch={epoch:03d}-" + f"{monitor_str}" + "={"
+                 + f"{monitor}" + ":.3f}",
+        auto_insert_metric_name=False,
+        save_top_k=1,
+        mode="max",
+        save_last=True
+    )
+    early_stopping_callback = EarlyStopping(monitor=monitor, min_delta=0.00,
+                                            patience=25, verbose=False, mode="max")
+    lr_monitor = LearningRateMonitor(logging_interval='step')
+
+    trainer = pl.Trainer(
+        max_epochs=epochs,
+        accelerator="auto",
+        log_every_n_steps=5,
+        logger=wandb_logger,
+        check_val_every_n_epoch=5,
+        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
+
+    )
+
+    model = LitVisionEncoder(config=model_config, lr=lr)
+    model = overwrite_vision_weights(model, vision_checkpoint)
+
+    hf_tokenizer, _ = get_huggingface_model(
+        model_name=text_model, load_pretrained_if_available=False
+    )
+    dm = RSVQAxBENDataModule(
+        data_dir=resolve_ben_data_dir(data_dir=data_dir, force_mock=True),
+        img_size=(channels, img_size, img_size),
+        num_workers_dataloader=num_workers_dataloader,
+        batch_size=batch_size,
+        max_img_idx=max_img_index,
+    )
+
+    trainer.fit(model=model, datamodule=dm)
+    trainer.test(model=model, datamodule=dm, ckpt_path="best")
+
+    wandb.finish()
+    print("=== Training finished ===")
+
+
+if __name__ == "__main__":
+    typer.run(main)
--
GitLab
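
For reference, the MUTAN fusion added in fusion.py can be exercised on its own with the same options that the mutan() helper in train_rsvqa.py builds. The following standalone sketch is not part of the patch; the batch size and the random inputs are illustrative assumptions only.

import torch

from fusion import MutanFusion

# Same options as mutan() in train_rsvqa.py: no extra visual/question embedding
# before fusion, 1200-d inputs fused into a 360-d vector with R=1.
opt = {
    'dim_hv': 1200,
    'dim_hq': 1200,
    'dim_mm': 360,
    'dropout_hv': 0.1,
    'dropout_hq': 0.1,
    'R': 1,
}
fusion = MutanFusion(opt, visual_embedding=False, question_embedding=False)

# Illustrative batch of 4 visual and 4 question feature vectors.
v = torch.randn(4, 1200)
q = torch.randn(4, 1200)
out = fusion(v, q)
print(out.shape)  # torch.Size([4, 360])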