From 19ed324e69b4323933630ff608cddd557233cdb4 Mon Sep 17 00:00:00 2001
From: Leonard Hackel <l.hackel@tu-berlin.de>
Date: Mon, 22 May 2023 09:32:31 +0200
Subject: [PATCH] adding precision to matmul op for torch>1.12

---
 train_lit4rsvqa.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index bd24a69..6117e57 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -249,13 +249,15 @@ def main(
         data_dir: str = None,
         test_run: bool = False,
         num_workers_dataloader: int = 4,
-        vision_checkpoint: str = None
+        vision_checkpoint: str = None,
+        matmul_precision: str = "medium",
 ):
     if test_run:
         max_img_index = 10 * batch_size
         epochs = 10
     else:
         max_img_index = -1
+    torch.set_float32_matmul_precision(matmul_precision)
 
     pl.seed_everything(seed, workers=True)
 
@@ -340,7 +342,8 @@ def main(
             "Seed": seed,
             "# Workers": num_workers_dataloader,
             "Vision Checkpoint": vision_checkpoint,
-            "GPU": torch.cuda.get_device_name()
+            "GPU": torch.cuda.get_device_name(),
+            "MatMul Precision": matmul_precision,
         }
     )
 
-- 
GitLab