Skip to content
Snippets Groups Projects
Commit 19ed324e authored by Leonard Wayne Hackel's avatar Leonard Wayne Hackel
Browse files

adding precision to matmul op for torch>1.12

parent b517a1ad
No related branches found
No related tags found
No related merge requests found
...@@ -249,13 +249,15 @@ def main( ...@@ -249,13 +249,15 @@ def main(
data_dir: str = None, data_dir: str = None,
test_run: bool = False, test_run: bool = False,
num_workers_dataloader: int = 4, num_workers_dataloader: int = 4,
vision_checkpoint: str = None vision_checkpoint: str = None,
matmul_precision: str = "medium",
): ):
if test_run: if test_run:
max_img_index = 10 * batch_size max_img_index = 10 * batch_size
epochs = 10 epochs = 10
else: else:
max_img_index = -1 max_img_index = -1
torch.set_float32_matmul_precision(matmul_precision)
pl.seed_everything(seed, workers=True) pl.seed_everything(seed, workers=True)
...@@ -340,7 +342,8 @@ def main( ...@@ -340,7 +342,8 @@ def main(
"Seed": seed, "Seed": seed,
"# Workers": num_workers_dataloader, "# Workers": num_workers_dataloader,
"Vision Checkpoint": vision_checkpoint, "Vision Checkpoint": vision_checkpoint,
"GPU": torch.cuda.get_device_name() "GPU": torch.cuda.get_device_name(),
"MatMul Precision": matmul_precision,
} }
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment