diff --git a/pretrain_lit4rsvqa.py b/pretrain_lit4rsvqa.py index 3e204fe8716d3f8da5a8cebfd0e6d216de9ec26d..59a86c4174af8f7b526cad1672c2acf422c40a93 100644 --- a/pretrain_lit4rsvqa.py +++ b/pretrain_lit4rsvqa.py @@ -147,7 +147,8 @@ def main( batch_size: int = 32, seed: int = 42, data_dir: str = None, - test_run: bool = False + test_run: bool = False, + num_workers_dataloader: int = 4 ): if test_run: max_img_index = 10 * batch_size @@ -212,7 +213,7 @@ def main( dm = BENDataModule( data_dir=resolve_ben_data_dir(data_dir=data_dir), img_size=(channels, img_size, img_size), - num_workers_dataloader=4, + num_workers_dataloader=num_workers_dataloader, batch_size=batch_size, max_img_idx=max_img_index, )