diff --git a/train_rsvqa.py b/train_rsvqa.py index e34232fdcc0cc7213901a353d98d5ad53ebbd052..4a02025c13dea8e8551caea3e4722735028c3d1f 100644 --- a/train_rsvqa.py +++ b/train_rsvqa.py @@ -334,7 +334,7 @@ def main( model_name=text_model, load_pretrained_if_available=False ) dm = RSVQAxBENDataModule( - data_dir=resolve_ben_data_dir(data_dir=data_dir, force_mock=True), + data_dir=resolve_ben_data_dir(data_dir=data_dir), img_size=(channels, img_size, img_size), num_workers_dataloader=num_workers_dataloader, batch_size=batch_size,