diff --git a/train_rsvqa.py b/train_rsvqa.py
index 05e2907de206024f5f121e6f144aec46c072c321..47ac1b2e20e377f60c80180b2716e4e37c1d95af 100644
--- a/train_rsvqa.py
+++ b/train_rsvqa.py
@@ -326,7 +326,7 @@ def main(
         logger=wandb_logger,
         check_val_every_n_epoch=2,
         callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
-
+        precision="16-mixed",
     )
 
     model = LitVisionEncoder(config=model_config, lr=lr)