From 19ed324e69b4323933630ff608cddd557233cdb4 Mon Sep 17 00:00:00 2001 From: Leonard Hackel <l.hackel@tu-berlin.de> Date: Mon, 22 May 2023 09:32:31 +0200 Subject: [PATCH] adding precision to matmul op for torch>1.12 --- train_lit4rsvqa.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py index bd24a69..6117e57 100644 --- a/train_lit4rsvqa.py +++ b/train_lit4rsvqa.py @@ -249,13 +249,15 @@ def main( data_dir: str = None, test_run: bool = False, num_workers_dataloader: int = 4, - vision_checkpoint: str = None + vision_checkpoint: str = None, + matmul_precision: str = "medium", ): if test_run: max_img_index = 10 * batch_size epochs = 10 else: max_img_index = -1 + torch.set_float32_matmul_precision(matmul_precision) pl.seed_everything(seed, workers=True) @@ -340,7 +342,8 @@ def main( "Seed": seed, "# Workers": num_workers_dataloader, "Vision Checkpoint": vision_checkpoint, - "GPU": torch.cuda.get_device_name() + "GPU": torch.cuda.get_device_name(), + "MatMul Precision": matmul_precision, } ) -- GitLab