Skip to content
Snippets Groups Projects
Commit 31664f3a authored by Leonard Wayne Hackel's avatar Leonard Wayne Hackel
Browse files

going back to cuda 12, removing hf mixins

parent 15bdd384
No related branches found
No related tags found
No related merge requests found
......@@ -7,18 +7,17 @@ from typing import List
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from configilm import ConfigILM
from configilm.ConfigILM import ILMConfiguration
from configilm.ConfigILM import ILMType
from configilm.extra.BENv2_utils import NEW_LABELS
from configilm.extra.CustomTorchClasses import LinearWarmupCosineAnnealingLR
from configilm.metrics import get_classification_metric_collection
from huggingface_hub import PyTorchModelHubMixin
__author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
class BENv2ImageEncoder(pl.LightningModule, PyTorchModelHubMixin):
# class BENv2ImageEncoder(pl.LightningModule, PyTorchModelHubMixin):
class BENv2ImageEncoder(pl.LightningModule):
"""
Wrapper around a pytorch module, allowing this module to be used in automatic
training with pytorch lightning.
......@@ -39,7 +38,7 @@ class BENv2ImageEncoder(pl.LightningModule, PyTorchModelHubMixin):
assert config.classes == 19
self.mock_params = torch.nn.Linear(config.classes, config.classes)
self.model = lambda x: self.mock_params(torch.rand((x.shape[0], config.classes)).to(x.device))
#self.model = ConfigILM.ConfigILM(config)
# self.model = ConfigILM.ConfigILM(config)
self.val_output_list: List[dict] = []
self.test_output_list: List[dict] = []
self.loss = torch.nn.BCEWithLogitsLoss()
......@@ -114,7 +113,6 @@ class BENv2ImageEncoder(pl.LightningModule, PyTorchModelHubMixin):
super().on_validation_epoch_start()
self.val_output_list = []
def on_validation_epoch_end(self):
avg_loss = torch.stack([x["loss"] for x in self.val_output_list]).mean()
self.log("val/loss", avg_loss)
......
This diff is collapsed.
......@@ -10,14 +10,6 @@ python = ">=3.10, <3.12"
configilm = { extras = ["full"], version = "^0.6.3" }
wandb = "^0.17.1"
numpy = "^1.26.4"
torch = {version = "^2.3.1+cu118", source = "pytorch-gpu-src"}
torchvision = {version = "^0.18.1+cu118", source = "pytorch-gpu-src"}
torchaudio = {version = "^2.3.1+cu118", source = "pytorch-gpu-src"}
[[tool.poetry.source]]
name = "pytorch-gpu-src"
url = "https://download.pytorch.org/whl/cu118"
priority = "explicit"
[build-system]
requires = ["poetry-core"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment