Skip to content
Snippets Groups Projects
Commit dee6db5b authored by Leonard Wayne Hackel's avatar Leonard Wayne Hackel
Browse files

adding mock ds

parent a1b0093f
No related branches found
No related tags found
No related merge requests found
import torch
from lightning.pytorch import LightningDataModule
from torch.utils.data.dataset import Dataset
class MockDataset(Dataset):
def __init__(self, dims, clss, length):
self.data = torch.rand(*dims)
self.targets = torch.rand(clss)
self.length = length
def __getitem__(self, index):
return self.data, self.targets
def __len__(self):
return self.length
class MockDataModule(LightningDataModule):
def __init__(self, dims, clss, train_length, val_length, test_length, bs):
super().__init__()
self.dims = dims
self.clss = clss
self.length = train_length
self.val_length = val_length
self.test_length = test_length
self.batch_size = bs
self.train_ds = MockDataset(self.dims, self.clss, self.length)
self.val_ds = MockDataset(self.dims, self.clss, self.val_length)
self.test_ds = MockDataset(self.dims, self.clss, self.test_length)
print("MockDataModule initialized")
def train_dataloader(self):
return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)
def val_dataloader(self):
return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)
def test_dataloader(self):
return torch.utils.data.DataLoader(self.test_ds, batch_size=self.batch_size)
......@@ -14,6 +14,7 @@ from configilm.extra.DataModules.BENv2_DataModule import BENv2DataModule
from lightning.pytorch.loggers import WandbLogger
from ben_publication.BENv2ImageClassifier import BENv2ImageEncoder
from ben_publication.mock_dm import MockDataModule
__author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
......@@ -114,12 +115,23 @@ def main(
data_dirs = resolve_data_dir(data_dirs, allow_mock=True)
print(f"Using data directories for {hostname}")
dm = BENv2DataModule(
data_dirs=data_dirs,
batch_size=bs,
num_workers_dataloader=workers,
img_size=(channels, 120, 120),
)
use_mock_data = True
if use_mock_data:
dm = MockDataModule(
dims=(channels, img_size, img_size),
clss=num_classes,
train_length=200_000,
val_length=120_000,
test_length=120_000,
bs=bs,
)
else:
dm = BENv2DataModule(
data_dirs=data_dirs,
batch_size=bs,
num_workers_dataloader=workers,
img_size=(channels, img_size, img_size),
)
# fixed model parameters based on the BigEarthNet v2.0 dataset
config = ILMConfiguration(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment