diff --git a/train_lit4rsvqa.py b/train_lit4rsvqa.py
index bd24a695dd265853b97313c297df146b37da9bc2..6117e57893f592834c848211c48e008905604695 100644
--- a/train_lit4rsvqa.py
+++ b/train_lit4rsvqa.py
@@ -249,13 +249,15 @@ def main(
         data_dir: str = None,
         test_run: bool = False,
         num_workers_dataloader: int = 4,
-        vision_checkpoint: str = None
+        vision_checkpoint: str = None,
+        matmul_precision: str = "medium",
 ):
     if test_run:
         max_img_index = 10 * batch_size
         epochs = 10
     else:
         max_img_index = -1
+    torch.set_float32_matmul_precision(matmul_precision)
 
     pl.seed_everything(seed, workers=True)
 
@@ -340,7 +342,8 @@ def main(
             "Seed": seed,
             "# Workers": num_workers_dataloader,
             "Vision Checkpoint": vision_checkpoint,
-            "GPU": torch.cuda.get_device_name()
+            "GPU": torch.cuda.get_device_name(),
+            "MatMul Precision": matmul_precision,
         }
     )