diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index 0e71df073311fb0eeedd9223ed2105765cd76ffa..3076169ba85337df7a2c5e7b59a1dfceca232a62 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -21,7 +21,7 @@ from pytorch_lightning.callbacks import LearningRateMonitor
 from sklearn.metrics import accuracy_score
 from torchmetrics.classification import MultilabelF1Score
 from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
-
+from fvcore.nn import parameter_count, FlopCountAnalysis
 
 __author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
 os.environ["WANDB_START_METHOD"] = "thread"
@@ -37,15 +37,26 @@ class LitVisionEncoder(pl.LightningModule):
     """
 
     def __init__(
         self,
         config: ConfigILM.ILMConfiguration,
         lr: float = 1e-3,
     ):
         super().__init__()
         self.lr = lr
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
+    def get_stats(self):
+        # build one dummy image and one dummy tokenized question (32 token ids)
+        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
+                                   self.config.image_size], device=self.device),
+                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
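+        # parameter_count returns a dict keyed by submodule name; '' holds the total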
+        params = parameter_count(self)
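+        # FlopCountAnalysis traces a single forward pass over the dummy inputs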
+        flops = FlopCountAnalysis(self, dummy_input)
+        return {"flops": flops.total(), "params": params['']}
+
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -86,7 +97,6 @@ class LitVisionEncoder(pl.LightningModule):
         }
         return [optimizer], [lr_scheduler]
 
-
     def validation_step(self, batch, batch_idx):
         x, y = self._disassemble_batch(batch)
         x_hat = self.model(x)
@@ -307,6 +317,11 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
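+    # run the (expensive) FLOP trace once and reuse the result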
+    stats = model.get_stats()
+    print(f"Model Stats: Params: {stats['params']:15,d}\n"
+          f"              Flops: {stats['flops']:15,d}")
+
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )
diff --git a/train_rsvqa.py b/train_rsvqa.py
index 2286005d0e5eaa3532579677287a37be45e6b392..22cd60606a167e0caea3d9d5a2bb4b18c3f4288a 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -21,6 +21,8 @@ from pytorch_lightning.callbacks import LearningRateMonitor
 from sklearn.metrics import accuracy_score
 from torchmetrics.classification import MultilabelF1Score
 from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
+from fvcore.nn import parameter_count, FlopCountAnalysis
+
 
 from fusion import MutanFusion
 
@@ -47,6 +49,17 @@ class LitVisionEncoder(pl.LightningModule):
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
+    def get_stats(self):
+        # build one dummy image and one dummy tokenized question (32 token ids)
+        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
+                                   self.config.image_size], device=self.device),
+                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
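+        # parameter_count returns a dict keyed by submodule name; '' holds the total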
+        params = parameter_count(self)
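+        # FlopCountAnalysis traces a single forward pass over the dummy inputs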
+        flops = FlopCountAnalysis(self, dummy_input)
+        return {"flops": flops.total(), "params": params['']}
+
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -330,6 +343,11 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
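+    # run the (expensive) FLOP trace once and reuse the result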
+    stats = model.get_stats()
+    print(f"Model Stats: Params: {stats['params']:15,d}\n"
+          f"              Flops: {stats['flops']:15,d}")
+
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )