From db77dec23a542a90eed685b34005adc23a3f93d7 Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Wed, 3 May 2023 14:05:07 +0200
Subject: [PATCH] removing flop count that does not work

---
 train_rsvqa.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/train_rsvqa.py b/train_rsvqa.py
index 22cd606..1e2beab 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -49,15 +49,6 @@ class LitVisionEncoder(pl.LightningModule):
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
-    def get_stats(self):
-        # create example image
-        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
-                                   self.config.image_size], device=self.device),
-                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
-        params = parameter_count(self)
-        flops = FlopCountAnalysis(self, dummy_input)
-        return {"flops": flops.total(), "params": params['']}
-
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -341,9 +332,6 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
-    print(f"Model Stats: Params: {model.get_stats()['params']:15,d}\n"
-          f"              Flops: {model.get_stats()['flops']:15,d}")
-
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )
-- 
GitLab