From 99800bc65bb78894d8e82ec562dbe01647abc394 Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Wed, 26 Apr 2023 15:32:42 +0200
Subject: [PATCH] adding training script and gitignore

---
 .gitignore         |   4 +
 train_lit4rsvqa.py | 327 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 331 insertions(+)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..37e1168
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+.idea/
+checkpoints/
+wandb/
diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index e69de29..04975c2 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -0,0 +1,327 @@
+# import packages
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+from torch import optim
+from tqdm import tqdm
+
+from configilm import ConfigILM
+from configilm.ConfigILM import ILMConfiguration, ILMType
+from configilm.ConfigILM import get_hf_model as get_huggingface_model
+from configilm.extra.RSVQAxBEN_DataModule_LMDB_Encoder import RSVQAxBENDataModule
+from configilm.extra.BEN_lmdb_utils import resolve_ben_data_dir
+import typer
+import os
+from os.path import isfile
+import wandb
+from pytorch_lightning.loggers.wandb import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import LearningRateMonitor
+from sklearn.metrics import accuracy_score
+from torchmetrics.classification import MultilabelF1Score
+from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
+
+
+__author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
+os.environ["WANDB_START_METHOD"] = "thread"
+wandb_api_key = os.environ["WANDB_API_KEY"]
+
+
+class LitVisionEncoder(pl.LightningModule):
+    """
+    Wrapper around a pytorch module, allowing this module to be used in automatic
+    training with pytorch lightning.
+    Among other things, the wrapper removes the need to manage data on different
+    devices (e.g. GPU and CPU).
+    """
+
+    def __init__(
+        self,
+        config: ConfigILM.ILMConfiguration,
+        lr: float = 1e-3,
+    ):
+        super().__init__()
+        self.lr = lr
+        self.config = config
+        self.model = ConfigILM.ConfigILM(config)
+
+    def _disassemble_batch(self, batch):
+        images, questions, labels = batch
+        # transpose tensor; needed for the Huggingface/DataLoader combination
+        questions = torch.tensor(
+            [x.tolist() for x in questions], device=self.device
+        ).T.int()
+        return (images, questions), labels
+
+    def training_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        self.log("train/loss", loss)
+        return {"loss": loss}
+
+    def configure_optimizers(self):
+        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
+
+        # these are steps if interval is set to step
+        max_intervals = int(self.trainer.max_epochs *
+                            len(self.trainer.datamodule.train_ds) /
+                            self.trainer.datamodule.batch_size)
+        warmup = 10000 if max_intervals > 10000 else 100 if max_intervals > 100 else 0
+
+        print(f"Optimizing for {max_intervals} steps with warmup for {warmup} steps")
+
+        lr_scheduler = {
+            'scheduler': LinearWarmupCosineAnnealingLR(
+                optimizer,
+                warmup_epochs=warmup,
+                max_epochs=max_intervals,
+                warmup_start_lr=self.lr / 10,
+                eta_min=self.lr / 10
+            ),
+            'name': 'learning_rate',
+            'interval': "step",
+            'frequency': 1
+        }
+        return [optimizer], [lr_scheduler]
+
+    def validation_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        return {"loss": loss, "outputs": x_hat, "labels": y}
+
+    def validation_epoch_end(self, outputs):
+        metrics = self.get_metrics(outputs)
+
+        self.log("val/loss", metrics["avg_loss"])
+        self.log("val/f1", metrics["avg_f1_score"])
+        self.log("val/Accuracy (LULC)", metrics["accuracy"]["LULC"])
+        self.log("val/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
+        self.log("val/Accuracy (Overall)", metrics["accuracy"]["Overall"])
+        self.log("val/Accuracy (Average)", metrics["accuracy"]["Average"])
+
+    def test_step(self, batch, batch_idx):
+        x, y = self._disassemble_batch(batch)
+        x_hat = self.model(x)
+        loss = F.binary_cross_entropy_with_logits(x_hat, y)
+        return {"loss": loss, "outputs": x_hat, "labels": y}
+
+    def test_epoch_end(self, outputs):
+        metrics = self.get_metrics(outputs)
+
+        self.log("test/loss", metrics["avg_loss"])
+        self.log("test/f1", metrics["avg_f1_score"])
+        self.log("test/Accuracy (LULC)", metrics["accuracy"]["LULC"])
+        self.log("test/Accuracy (Yes-No)", metrics["accuracy"]["Yes/No"])
+        self.log("test/Accuracy (Overall)", metrics["accuracy"]["Overall"])
+        self.log("test/Accuracy (Average)", metrics["accuracy"]["Average"])
+
+    def forward(self, batch):
+        # because we are a wrapper, we call the inner function manually
+        return self.model(batch)
+
+    def get_metrics(self, outputs):
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+        logits = torch.cat([x["outputs"].cpu() for x in outputs], 0)
+        labels = torch.cat(
+            [x["labels"].cpu() for x in outputs], 0
+        )  # Tensor of size (#samples x classes)
+
+        selected_answers = self.trainer.datamodule.selected_answers
+
+        argmax_out = torch.argmax(logits, dim=1)
+        argmax_lbl = torch.argmax(labels, dim=1)
+
+        # get answers and predictions per type
+        yn_preds = []
+        yn_gts = []
+        lulc_preds = []
+        lulc_gts = []
+
+        for i, ans in enumerate(tqdm(argmax_lbl, desc="Counting answers")):
+            # Yes/No question
+            if selected_answers[ans] in ["yes", "no"]:
+                # stored for global Yes/No
+                yn_preds.append(argmax_out[i])
+                yn_gts.append(ans)
+
+            # LULC question
+            else:
+                # stored for global LULC
+                lulc_preds.append(argmax_out[i])
+                lulc_gts.append(ans)
+
+        acc_yn = accuracy_score(yn_gts, yn_preds)
+        acc_lulc = accuracy_score(lulc_gts, lulc_preds)
+
+        accuracy_dict = {
+            "Yes/No": acc_yn,
+            "LULC": acc_lulc,
+            "Overall": accuracy_score(
+                argmax_lbl, argmax_out
+            ),  # micro average on classes
+            "Average": (acc_yn + acc_lulc) / 2,  # macro average on types
+        }
+
+        f1_score = MultilabelF1Score(num_labels=self.config.classes, average=None).to(
+            logits.device
+        )(logits, labels)
+
+        avg_f1_score = float(
+            torch.sum(f1_score) / self.config.classes
+        )  # macro average f1 score
+
+        return {
+            "avg_loss": avg_loss,
+            "avg_f1_score": avg_f1_score,
+            "accuracy": accuracy_dict,
+        }
+
+
+def overwrite_vision_weights(model, vision_checkpoint):
+    if vision_checkpoint is None:
+        return model
+    if not isfile(vision_checkpoint):
+        print("Pretrained vision model not available, cannot load checkpoint")
+        return model
+    # load weights
+    # get model and pretrained state dicts
+    if torch.cuda.is_available():
+        pretrained_dict = torch.load(vision_checkpoint)
+    else:
+        pretrained_dict = torch.load(
+            vision_checkpoint, map_location=torch.device("cpu")
+        )
+    model_dict = model.state_dict()
+
+    # filter out unnecessary keys
+    # this allows loading either a Pytorch-Lightning or a plain pytorch checkpoint
+    if "pytorch-lightning_version" in pretrained_dict.keys():
+        # checkpoint is a Pytorch-Lightning Checkpoint
+        pretrained_dict = {
+            k: v
+            for k, v in pretrained_dict["state_dict"].items()
+            if k in model_dict
+        }
+    else:
+        pretrained_dict = {
+            k: v for k, v in pretrained_dict.items() if k in model_dict
+        }
+
+    # filter keys that have a size mismatch
+    mismatch_keys = [
+        x
+        for x in pretrained_dict.keys()
+        if pretrained_dict[x].shape != model_dict[x].shape
+    ]
+    for key in mismatch_keys:
+        del pretrained_dict[key]
+        print(f"Key '{key}' has a size mismatch, removing it before loading")
+
+    # overwrite entries in the existing state dict
+    model_dict.update(pretrained_dict)
+
+    # load the new state dict
+    model.load_state_dict(model_dict)
+    print("Vision Model checkpoint loaded")
+    return model
+
+
+def main(
+    vision_model: str = "mobilevit_s",
+    text_model: str = "prajjwal1/bert-tiny",
+    lr: float = 1e-3,
+    epochs: int = 100,
+    batch_size: int = 32,
+    seed: int = 42,
+    data_dir: str = None,
+    test_run: bool = False,
+    num_workers_dataloader: int = 4,
+    vision_checkpoint: str = None
+):
+    if test_run:
+        max_img_index = 10 * batch_size
+        epochs = 10
+    else:
+        max_img_index = -1
+
+    pl.seed_everything(seed, workers=True)
+
+    img_size = 120
+    channels = 10
+
+    model_config = ILMConfiguration(
+        timm_model_name=vision_model,
+        hf_model_name=text_model,
+        classes=1000,
+        image_size=img_size,
+        channels=channels,
+        network_type=ILMType.VQA_CLASSIFICATION
+    )
+
+    # The key is provided by wandb; the project name can be chosen freely
+    wandb.login(key=wandb_api_key)
+
+    tags = ["Training", vision_model, text_model]
+    if test_run:
+        tags += ["Test Run"]
+    wandb_logger = WandbLogger(project="LiT4RSVQA",
+                               log_model=True,
+                               tags=tags,  # keyword arg directly to wandb.init()
+                               )
+
+    monitor = "val/f1"
+    monitor_str = "F1_score"
+    # checkpointing
+    checkpoint_callback = ModelCheckpoint(
+        monitor=monitor,
+        dirpath="./checkpoints",
+        filename=f"{wandb_logger.experiment.name}-{vision_model}-{text_model}-seed="
+                 + str(seed) + "-epoch={epoch:03d}-" + monitor_str + "={"
+                 + monitor + ":.3f}",
+        auto_insert_metric_name=False,
+        save_top_k=1,
+        mode="max",
+        save_last=True
+    )
+    early_stopping_callback = EarlyStopping(monitor=monitor, min_delta=0.00,
+                                            patience=25, verbose=False, mode="max")
+    lr_monitor = LearningRateMonitor(logging_interval='step')
+
+    trainer = pl.Trainer(
+        max_epochs=epochs,
+        accelerator="auto",
+        log_every_n_steps=5,
+        logger=wandb_logger,
+        check_val_every_n_epoch=5,
+        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
+    )
+
+    model = LitVisionEncoder(config=model_config, lr=lr)
+    model = overwrite_vision_weights(model, vision_checkpoint)
+
+    hf_tokenizer, _ = get_huggingface_model(
+        model_name=text_model, load_pretrained_if_available=False
+    )
+    dm = RSVQAxBENDataModule(
+        data_dir=resolve_ben_data_dir(data_dir=data_dir),
+        img_size=(channels, img_size, img_size),
+        num_workers_dataloader=num_workers_dataloader,
+        batch_size=batch_size,
+        max_img_idx=max_img_index,
+    )
+
+    trainer.fit(model=model, datamodule=dm)
+    trainer.test(model=model, datamodule=dm, ckpt_path="best")
+
+    wandb.finish()
+    print("=== Training finished ===")
+
+
+if __name__ == "__main__":
+    typer.run(main)
--
GitLab
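
Note: LinearWarmupCosineAnnealingLR is imported from a local LinWarCosAnLR module that is not included in this patch. For readers without that file, here is a minimal sketch of what a scheduler with this constructor signature typically computes: a linear ramp from warmup_start_lr up to the base learning rate, followed by cosine annealing down to eta_min. The "epochs" arguments are counted in optimizer steps here, since configure_optimizers sets the scheduler interval to "step". This is an assumption about the missing module, not its actual contents.

    # Sketch only; the real LinWarCosAnLR module may differ in detail.
    import math

    from torch.optim.lr_scheduler import _LRScheduler


    class LinearWarmupCosineAnnealingLR(_LRScheduler):
        """Linear warmup, then cosine annealing; units are scheduler steps."""

        def __init__(self, optimizer, warmup_epochs, max_epochs,
                     warmup_start_lr=0.0, eta_min=0.0, last_epoch=-1):
            self.warmup_epochs = warmup_epochs
            self.max_epochs = max_epochs
            self.warmup_start_lr = warmup_start_lr
            self.eta_min = eta_min
            super().__init__(optimizer, last_epoch)

        def get_lr(self):
            step = self.last_epoch
            if step < self.warmup_epochs:
                # linear ramp from warmup_start_lr to each group's base lr
                return [
                    self.warmup_start_lr
                    + (base_lr - self.warmup_start_lr) * step / max(1, self.warmup_epochs)
                    for base_lr in self.base_lrs
                ]
            # cosine decay from each base lr down to eta_min
            progress = (step - self.warmup_epochs) / max(1, self.max_epochs - self.warmup_epochs)
            return [
                self.eta_min
                + (base_lr - self.eta_min) * (1 + math.cos(math.pi * progress)) / 2
                for base_lr in self.base_lrs
            ]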
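
Usage note: since the script hands main() to typer.run, every parameter of main() becomes a command-line option (typer converts underscores to dashes, e.g. --num-workers-dataloader, and the boolean becomes --test-run/--no-test-run). A quick smoke test could be started with `python train_lit4rsvqa.py --batch-size 8 --test-run`, or programmatically as sketched below. The values are illustrative; the run assumes WANDB_API_KEY is exported and the RSVQAxBEN LMDB data is reachable via resolve_ben_data_dir.

    # Illustrative smoke test; all argument names come from main()'s signature.
    # test_run=True caps the data at 10 * batch_size images and trains 10 epochs.
    from train_lit4rsvqa import main

    main(
        vision_model="mobilevit_s",
        text_model="prajjwal1/bert-tiny",
        batch_size=8,
        num_workers_dataloader=2,
        test_run=True,
    )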