diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index 0e71df073311fb0eeedd9223ed2105765cd76ffa..3076169ba85337df7a2c5e7b59a1dfceca232a62 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -21,7 +21,7 @@ from pytorch_lightning.callbacks import LearningRateMonitor
 from sklearn.metrics import accuracy_score
 from torchmetrics.classification import MultilabelF1Score
 from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
-
+from fvcore.nn import parameter_count, FlopCountAnalysis
 
 __author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
 os.environ["WANDB_START_METHOD"] = "thread"
@@ -37,15 +37,24 @@ class LitVisionEncoder(pl.LightningModule):
     """
 
     def __init__(
-        self,
-        config: ConfigILM.ILMConfiguration,
-        lr: float = 1e-3,
+            self,
+            config: ConfigILM.ILMConfiguration,
+            lr: float = 1e-3,
     ):
         super().__init__()
         self.lr = lr
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
+    def get_stats(self):
+        # create dummy inputs: one image and one tokenized question of length 32
+        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
+                                   self.config.image_size], device=self.device),
+                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
+        params = parameter_count(self)
+        flops = FlopCountAnalysis(self, dummy_input)  # traces one forward pass
+        return {"flops": flops.total(), "params": params['']}  # '' = total over all submodules
+
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -86,7 +95,6 @@ class LitVisionEncoder(pl.LightningModule):
         }
         return [optimizer], [lr_scheduler]
 
-
     def validation_step(self, batch, batch_idx):
         x, y = self._disassemble_batch(batch)
         x_hat = self.model(x)
@@ -307,6 +315,10 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
+    stats = model.get_stats()  # compute once; the FLOP trace is expensive
+    print(f"Model Stats: Params: {stats['params']:15,d}\n"
+          f"             Flops:  {stats['flops']:15,d}")
+
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )
diff --git a/train_rsvqa.py b/train_rsvqa.py
index 2286005d0e5eaa3532579677287a37be45e6b392..22cd60606a167e0caea3d9d5a2bb4b18c3f4288a 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -21,6 +21,8 @@ from pytorch_lightning.callbacks import LearningRateMonitor
 from sklearn.metrics import accuracy_score
 from torchmetrics.classification import MultilabelF1Score
 from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
+from fvcore.nn import parameter_count, FlopCountAnalysis
+
 from fusion import MutanFusion
 
 
@@ -47,6 +49,15 @@ class LitVisionEncoder(pl.LightningModule):
         self.config = config
         self.model = ConfigILM.ConfigILM(config)
 
+    def get_stats(self):
+        # create dummy inputs: one image and one tokenized question of length 32
+        dummy_input = [torch.rand([1, self.config.channels, self.config.image_size,
+                                   self.config.image_size], device=self.device),
+                       torch.ones([1, 32], device=self.device, dtype=torch.int)]
+        params = parameter_count(self)
+        flops = FlopCountAnalysis(self, dummy_input)  # traces one forward pass
+        return {"flops": flops.total(), "params": params['']}  # '' = total over all submodules
+
     def _disassemble_batch(self, batch):
         images, questions, labels = batch
         # transposing tensor, needed for Huggingface-Dataloader combination
@@ -330,6 +341,9 @@ def main(
     model = LitVisionEncoder(config=model_config, lr=lr)
     model = overwrite_vision_weights(model, vision_checkpoint)
 
+    stats = model.get_stats()  # compute once; the FLOP trace is expensive
+    print(f"Model Stats: Params: {stats['params']:15,d}\n"
+          f"             Flops:  {stats['flops']:15,d}")
+
     hf_tokenizer, _ = get_huggingface_model(
         model_name=text_model, load_pretrained_if_available=False
     )
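
Note on the fvcore usage above: parameter_count() returns a dict mapping submodule names to parameter counts, with the empty-string key holding the total for the whole model, and FlopCountAnalysis traces a single forward pass over the given dummy inputs. A minimal, self-contained sketch of the same pattern, using a hypothetical ToyModel in place of the ConfigILM models built in these scripts:

import torch
from torch import nn
from fvcore.nn import parameter_count, FlopCountAnalysis


class ToyModel(nn.Module):
    """Hypothetical stand-in for the ConfigILM model; illustration only."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 4)

    def forward(self, x):
        return self.proj(x)


model = ToyModel()
dummy = torch.rand(1, 16)  # batch of one, shaped like forward()'s input

params = parameter_count(model)          # dict: submodule name -> #params; '' is the full model
flops = FlopCountAnalysis(model, dummy)  # traces forward(dummy) and counts operations

print(f"params: {params['']:,d}")     # 16 * 4 weights + 4 biases = 68
print(f"flops:  {flops.total():,d}")  # 1 * 16 * 4 = 64 for the matmul

The patch applies the same two calls to the whole LightningModule, so the reported numbers cover the vision encoder, text encoder, and fusion head together.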