diff --git a/train_rsvqa.py b/train_rsvqa.py index 47ac1b2e20e377f60c80180b2716e4e37c1d95af..446105777eaf0f83edf5047851d337b9bd3de50a 100644 --- a/train_rsvqa.py +++ b/train_rsvqa.py @@ -326,7 +326,7 @@ def main( logger=wandb_logger, check_val_every_n_epoch=2, callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor], - precision="16-mixed", + precision="16", ) model = LitVisionEncoder(config=model_config, lr=lr)