From db77dec23a542a90eed685b34005adc23a3f93d7 Mon Sep 17 00:00:00 2001 From: Leonard Hackel <l.hackel@tu-berlin.de> Date: Wed, 3 May 2023 14:05:07 +0200 Subject: [PATCH] removing flop count that does not work --- train_rsvqa.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/train_rsvqa.py b/train_rsvqa.py index 22cd606..1e2beab 100644 --- a/train_rsvqa.py +++ b/train_rsvqa.py @@ -49,15 +49,6 @@ class LitVisionEncoder(pl.LightningModule): self.config = config self.model = ConfigILM.ConfigILM(config) - def get_stats(self): - # create example image - dummy_input = [torch.rand([1, self.config.channels, self.config.image_size, - self.config.image_size], device=self.device), - torch.ones([1, 32], device=self.device, dtype=torch.int)] - params = parameter_count(self) - flops = FlopCountAnalysis(self, dummy_input) - return {"flops": flops.total(), "params": params['']} - def _disassemble_batch(self, batch): images, questions, labels = batch # transposing tensor, needed for Huggingface-Dataloader combination @@ -341,9 +332,6 @@ def main( model = LitVisionEncoder(config=model_config, lr=lr) model = overwrite_vision_weights(model, vision_checkpoint) - print(f"Model Stats: Params: {model.get_stats()['params']:15,d}\n" - f" Flops: {model.get_stats()['flops']:15,d}") - hf_tokenizer, _ = get_huggingface_model( model_name=text_model, load_pretrained_if_available=False ) -- GitLab