From 658ce297f32671157a613543aeafee753b0d1856 Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Wed, 3 May 2023 09:52:41 +0200
Subject: [PATCH] adding flop analysis

---
 train_lit4rsvqa.py | 21 ++++++++++++++++-----
 train_rsvqa.py     | 14 ++++++++++++++
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index 0e71df0..3076169 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -21,7 +21,7 @@ from pytorch_lightning.callbacks import LearningRateMonitor
 from sklearn.metrics import accuracy_score
 from torchmetrics.classification import MultilabelF1Score
 from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
-
+from fvcore.nn import parameter_count, FlopCountAnalysis
 
 __author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
 os.environ["WANDB_START_METHOD"] = "thread"
@@ -37,15 +37,24 @@ class LitVisionEncoder(pl.LightningModule):
     """
 
     def __init__(
-        self,
-        config: ConfigILM.ILMConfiguration,
-        lr: float = 1e-3,
+            self,
+            config: ConfigILM.ILMConfiguration,
+            lr: float = 1e-3,
     ):
         super().__init__()
         self.lr = lr
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
+    def get_stats(self):
+        # create example image
+        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
+                                   self.config.image_size], device=self.device),
+                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
+        params = parameter_count(self)
+        flops = FlopCountAnalysis(self, dummy_input)
+        return {"flops": flops.total(), "params": params['']}
+
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -86,7 +95,6 @@ class LitVisionEncoder(pl.LightningModule):
         }
         return [optimizer], [lr_scheduler]
 
-
     def validation_step(self, batch, batch_idx):
         x, y = self._disassemble_batch(batch)
         x_hat = self.model(x)
@@ -307,6 +315,9 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
+    print(f"Model Stats: Params: {model.get_stats()['params']:15,d}\n"
+          f"             Flops:  {model.get_stats()['flops']:15,d}")
+
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )
diff --git a/train_rsvqa.py b/train_rsvqa.py
index 2286005..22cd606 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -21,6 +21,8 @@ from pytorch_lightning.callbacks import LearningRateMonitor
 from sklearn.metrics import accuracy_score
 from torchmetrics.classification import MultilabelF1Score
 from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
+from fvcore.nn import parameter_count, FlopCountAnalysis
+
 
 from fusion import MutanFusion
 
@@ -47,6 +49,15 @@ class LitVisionEncoder(pl.LightningModule):
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
+    def get_stats(self):
+        # create example image
+        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
+                                   self.config.image_size], device=self.device),
+                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
+        params = parameter_count(self)
+        flops = FlopCountAnalysis(self, dummy_input)
+        return {"flops": flops.total(), "params": params['']}
+
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -330,6 +341,9 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
+    print(f"Model Stats: Params: {model.get_stats()['params']:15,d}\n"
+          f"             Flops:  {model.get_stats()['flops']:15,d}")
+
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )
--
GitLab