diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py index bd24a695dd265853b97313c297df146b37da9bc2..6117e57893f592834c848211c48e008905604695 100644 --- a/train_lit4rsvqa.py +++ b/train_lit4rsvqa.py @@ -249,13 +249,15 @@ def main( data_dir: str = None, test_run: bool = False, num_workers_dataloader: int = 4, - vision_checkpoint: str = None + vision_checkpoint: str = None, + matmul_precision: str = "medium", ): if test_run: max_img_index = 10 * batch_size epochs = 10 else: max_img_index = -1 + torch.set_float32_matmul_precision(matmul_precision) pl.seed_everything(seed, workers=True) @@ -340,7 +342,8 @@ def main( "Seed": seed, "# Workers": num_workers_dataloader, "Vision Checkpoint": vision_checkpoint, - "GPU": torch.cuda.get_device_name() + "GPU": torch.cuda.get_device_name(), + "MatMul Precision": matmul_precision, } )