From 658ce297f32671157a613543aeafee753b0d1856 Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Wed, 3 May 2023 09:52:41 +0200
Subject: [PATCH] adding flop analysis

---
 train_lit4rsvqa.py | 21 ++++++++++++++++-----
 train_rsvqa.py     | 14 ++++++++++++++
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index 0e71df0..3076169 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -21,7 +21,7 @@ from pytorch_lightning.callbacks import LearningRateMonitor
 from sklearn.metrics import accuracy_score
 from torchmetrics.classification import MultilabelF1Score
 from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
-
+from fvcore.nn import parameter_count, FlopCountAnalysis
 
 __author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
 os.environ["WANDB_START_METHOD"] = "thread"
@@ -37,15 +37,24 @@ class LitVisionEncoder(pl.LightningModule):
     """
 
     def __init__(
-        self,
-        config: ConfigILM.ILMConfiguration,
-        lr: float = 1e-3,
+            self,
+            config: ConfigILM.ILMConfiguration,
+            lr: float = 1e-3,
     ):
         super().__init__()
         self.lr = lr
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
+    def get_stats(self):
+        # create dummy inputs: a single image plus a 32-token question
+        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
+                                   self.config.image_size], device=self.device),
+                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
+        params = parameter_count(self)
+        flops = FlopCountAnalysis(self, dummy_input)
+        return {"flops": flops.total(), "params": params['']}
+
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -86,7 +95,6 @@ class LitVisionEncoder(pl.LightningModule):
         }
         return [optimizer], [lr_scheduler]
 
-
     def validation_step(self, batch, batch_idx):
         x, y = self._disassemble_batch(batch)
         x_hat = self.model(x)
@@ -307,6 +315,9 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
+    stats = model.get_stats()  # compute once; FlopCountAnalysis runs a full forward pass
+    print(f"Model Stats: Params: {stats['params']:15,d}\n              Flops: {stats['flops']:15,d}")
+
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )
diff --git a/train_rsvqa.py b/train_rsvqa.py
index 2286005..22cd606 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -21,6 +21,8 @@ from pytorch_lightning.callbacks import LearningRateMonitor
 from sklearn.metrics import accuracy_score
 from torchmetrics.classification import MultilabelF1Score
 from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
+from fvcore.nn import parameter_count, FlopCountAnalysis
+
 
 from fusion import MutanFusion
 
@@ -47,6 +49,15 @@ class LitVisionEncoder(pl.LightningModule):
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
+    def get_stats(self):
+        # create dummy inputs: a single image plus a 32-token question
+        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
+                                   self.config.image_size], device=self.device),
+                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
+        params = parameter_count(self)
+        flops = FlopCountAnalysis(self, dummy_input)
+        return {"flops": flops.total(), "params": params['']}
+
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -330,6 +341,9 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
+    stats = model.get_stats()  # compute once; FlopCountAnalysis runs a full forward pass
+    print(f"Model Stats: Params: {stats['params']:15,d}\n              Flops: {stats['flops']:15,d}")
+
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )
-- 
GitLab