Commit 4674bb1f authored by Leonard Wayne Hackel
init with env-v1 and pretraining script

# Copied (and slightly modified) from pl_bolts because importing it from there raised an import error;
# original: from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
import math
import warnings
from typing import List
from torch import nn
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import _LRScheduler
class LinearWarmupCosineAnnealingLR(_LRScheduler):
"""Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
    .. warning::
        It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
        after each iteration, as calling it only after each epoch keeps the learning rate at
        warmup_start_lr (0 in most cases) for the whole first epoch.
    .. warning::
        Passing an epoch to :func:`.step()` is deprecated and emits an EPOCH_DEPRECATION_WARNING.
        It makes the scheduler use :func:`_get_closed_form_lr()` instead of :func:`get_lr()`.
        Although this does not change the behaviour of the scheduler, when passing the epoch
        parameter to :func:`.step()`, :func:`.step()` should be called before the train and
        validation methods.
Example:
>>> layer = nn.Linear(10, 1)
>>> optimizer = Adam(layer.parameters(), lr=0.02)
>>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
>>> #
>>> # the default case
>>> for epoch in range(40):
... # train(...)
... # validate(...)
... scheduler.step()
>>> #
>>> # passing epoch param case
>>> for epoch in range(40):
... scheduler.step(epoch)
... # train(...)
... # validate(...)
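
    Note:
        The closed-form schedule (implemented in :func:`_get_closed_form_lr` and used when an
        epoch is passed to :func:`.step()`) is, for ``last_epoch < warmup_epochs``::

            lr = warmup_start_lr + last_epoch * (base_lr - warmup_start_lr) / (warmup_epochs - 1)

        and, for the cosine annealing phase afterwards::

            lr = eta_min + 0.5 * (base_lr - eta_min)
                 * (1 + cos(pi * (last_epoch - warmup_epochs) / (max_epochs - warmup_epochs)))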
"""
def __init__(
self,
optimizer: Optimizer,
warmup_epochs: int,
max_epochs: int,
warmup_start_lr: float = 0.0,
eta_min: float = 0.0,
last_epoch: int = -1,
) -> None:
"""
Args:
optimizer (Optimizer): Wrapped optimizer.
warmup_epochs (int): Maximum number of iterations for linear warmup
max_epochs (int): Maximum number of iterations
warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
eta_min (float): Minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
"""
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs
self.warmup_start_lr = warmup_start_lr
self.eta_min = eta_min
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
"""Compute learning rate using chainable form of the scheduler."""
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.",
UserWarning,
)
if self.last_epoch == 0:
return [self.warmup_start_lr] * len(self.base_lrs)
if self.last_epoch < self.warmup_epochs:
return [
group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
]
if self.last_epoch == self.warmup_epochs:
return self.base_lrs
if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
return [
group["lr"]
+ (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
]
return [
(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
/ (
1
+ math.cos(
math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
)
)
* (group["lr"] - self.eta_min)
+ self.eta_min
for group in self.optimizer.param_groups
]
def _get_closed_form_lr(self) -> List[float]:
"""Called when epoch is passed as a param to the `step` function of the scheduler."""
if self.last_epoch < self.warmup_epochs:
return [
self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
for base_lr in self.base_lrs
]
return [
self.eta_min
+ 0.5
* (base_lr - self.eta_min)
* (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
for base_lr in self.base_lrs
]
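# Minimal self-check sketch (an addition, not part of the pl_bolts original): stepping the
# scheduler once per iteration ramps the LR linearly from `warmup_start_lr` to the base LR
# over the first `warmup_epochs` steps and then decays it towards `eta_min` on a cosine
# curve. Running this file directly prints the resulting LR trajectory.
if __name__ == "__main__":
    _layer = nn.Linear(10, 1)
    _optimizer = Adam(_layer.parameters(), lr=0.02)
    _scheduler = LinearWarmupCosineAnnealingLR(
        _optimizer, warmup_epochs=10, max_epochs=40, warmup_start_lr=0.002, eta_min=0.002
    )
    for _step in range(40):
        _optimizer.step()  # a real training loop would compute gradients before this call
        _scheduler.step()
        print(_step, _scheduler.get_last_lr())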
name: lit4rsvqa
channels:
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- appdirs=1.4.4=pyh9f0ad1d_0
- brotlipy=0.7.0=py310h5764c6d_1005
- bzip2=1.0.8=h7f98852_4
- ca-certificates=2022.12.7=ha878542_0
- certifi=2022.12.7=pyhd8ed1ab_0
- cffi=1.15.1=py310h255011f_3
- charset-normalizer=3.1.0=pyhd8ed1ab_0
- click=8.1.3=unix_pyhd8ed1ab_2
- cryptography=40.0.2=py310h34c0648_0
- docker-pycreds=0.4.0=py_0
- gitdb=4.0.10=pyhd8ed1ab_0
- gitpython=3.1.31=pyhd8ed1ab_0
- idna=3.4=pyhd8ed1ab_0
- ld_impl_linux-64=2.40=h41732ed_0
- libffi=3.4.2=h7f98852_5
- libgcc-ng=12.2.0=h65d4601_19
- libgomp=12.2.0=h65d4601_19
- libnsl=2.0.0=h7f98852_0
- libprotobuf=3.21.12=h3eb15da_0
- libsqlite=3.40.0=h753d276_1
- libstdcxx-ng=12.2.0=h46fd767_19
- libuuid=2.38.1=h0b41bf4_0
- libzlib=1.2.13=h166bdaf_4
- ncurses=6.3=h27087fc_1
- openssl=3.1.0=hd590300_2
- pathtools=0.1.2=py_1
- pip=23.1.1=pyhd8ed1ab_0
- protobuf=4.21.12=py310heca2aa9_0
- psutil=5.9.5=py310h1fa729e_0
- pycparser=2.21=pyhd8ed1ab_0
- pyopenssl=23.1.1=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.10.10=he550d4f_0_cpython
- python_abi=3.10=3_cp310
- pyyaml=6.0=py310h5764c6d_5
- readline=8.2=h8228510_1
- requests=2.28.2=pyhd8ed1ab_1
- sentry-sdk=1.21.0=pyhd8ed1ab_0
- setproctitle=1.3.2=py310h5764c6d_1
- setuptools=67.7.2=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- smmap=3.0.5=pyh44b312d_0
- tk=8.6.12=h27826a3_0
- typing_extensions=4.5.0=pyha770c72_0
- tzdata=2023c=h71feb2d_0
- urllib3=1.26.15=pyhd8ed1ab_0
- wandb=0.15.0=pyhd8ed1ab_0
- wheel=0.40.0=pyhd8ed1ab_0
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- pip:
- attrs==23.1.0
- colorama==0.4.6
- configilm==0.2.0
- cycler==0.11.0
- fonttools==4.39.3
- kiwisolver==1.4.4
- packaging==23.1
- pillow==9.5.0
- pygments==2.15.1
- pyparsing==3.0.9
- python-dateutil==2.8.2
- pytz==2023.3
- scipy==1.10.1
prefix: /home/lhackel/mambaforge/envs/lit4rsvqa
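# Usage note (added, not part of the original file): an environment matching this specification
# can typically be created with `conda env create -f <path to this file>` (or the mamba
# equivalent) and then activated with `conda activate lit4rsvqa`.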
# import packages
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import optim
from configilm import ConfigILM
from configilm.ConfigILM import ILMConfiguration, ILMType
from configilm.extra.BEN_DataModule_LMDB_Encoder import BENDataModule
from configilm.extra.BEN_lmdb_utils import resolve_ben_data_dir
from typing import Optional
import typer
import os
import wandb
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from torchmetrics.classification import MultilabelAveragePrecision
from torchmetrics.classification import MultilabelF1Score
from LinWarCosAnLR import LinearWarmupCosineAnnealingLR
__author__ = "Leonard Hackel - BIFOLD/RSiM TU Berlin"
os.environ["WANDB_START_METHOD"] = "thread"
wandb_api_key = os.environ["WANDB_API_KEY"]
class LitVisionEncoder(pl.LightningModule):
"""
Wrapper around a pytorch module, allowing this module to be used in automatic
training with pytorch lightning.
Among other things, the wrapper allows us to do automatic training and removes the
need to manage data on different devices (e.g. GPU and CPU).
"""
def __init__(
self,
config: ConfigILM.ILMConfiguration,
lr: float = 1e-3,
):
super().__init__()
self.lr = lr
self.config = config
self.model = ConfigILM.ConfigILM(config)
def training_step(self, batch, batch_idx):
x, y = batch
x_hat = self.model(x)
loss = F.binary_cross_entropy_with_logits(x_hat, y)
self.log("train/loss", loss)
return {"loss": loss}
def configure_optimizers(self):
optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
        # max_intervals counts optimizer steps, since the scheduler interval below is set to "step"
max_intervals = int(self.trainer.max_epochs *
len(self.trainer.datamodule.train_ds) /
self.trainer.datamodule.batch_size)
        # warm up for at most 10k steps; use a shorter warmup (or none) for short runs
        warmup = 10000 if max_intervals > 10000 else 100 if max_intervals > 100 else 0
print(f"Optimizing for {max_intervals} steps with warmup for {warmup} steps")
lr_scheduler = {
'scheduler': LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=warmup,
max_epochs=max_intervals,
warmup_start_lr=self.lr / 10,
eta_min=self.lr / 10
),
'name': 'learning_rate',
'interval': "step",
'frequency': 1
}
return [optimizer], [lr_scheduler]
def validation_step(self, batch, batch_idx):
x, y = batch
x_hat = self.model(x)
loss = F.binary_cross_entropy_with_logits(x_hat, y)
return {"loss": loss, "outputs": x_hat, "labels": y}
def validation_epoch_end(self, outputs):
metrics = self.get_metrics(outputs)
self.log("val/loss", metrics["avg_loss"])
self.log("val/f1", metrics["avg_f1_score"])
self.log("val/mAP (Micro)", metrics["map_score"]["micro"])
self.log("val/mAP (Macro)", metrics["map_score"]["macro"])
def test_step(self, batch, batch_idx):
x, y = batch
x_hat = self.model(x)
loss = F.binary_cross_entropy_with_logits(x_hat, y)
return {"loss": loss, "outputs": x_hat, "labels": y}
def test_epoch_end(self, outputs):
metrics = self.get_metrics(outputs)
self.log("test/loss", metrics["avg_loss"])
self.log("test/f1", metrics["avg_f1_score"])
self.log("test/mAP (Micro)", metrics["map_score"]["micro"])
self.log("test/mAP (Macro)", metrics["map_score"]["macro"])
def forward(self, batch):
# because we are a wrapper, we call the inner function manually
return self.model(batch)
def get_metrics(self, outputs):
avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
logits = torch.cat([x["outputs"].cpu() for x in outputs], 0)
labels = torch.cat(
[x["labels"].cpu() for x in outputs], 0
) # Tensor of size (#samples x classes)
f1_score = MultilabelF1Score(num_labels=self.config.classes, average=None).to(
logits.device
)(logits, labels)
# calculate AP
ap_micro = MultilabelAveragePrecision(
num_labels=self.config.classes, average="micro"
).to(logits.device)(logits, labels.int())
ap_macro = MultilabelAveragePrecision(
num_labels=self.config.classes, average="macro"
).to(logits.device)(logits, labels.int())
ap_score = {"micro": float(ap_micro), "macro": float(ap_macro)}
avg_f1_score = float(
torch.sum(f1_score) / self.config.classes
) # macro average f1 score
return {
"avg_loss": avg_loss,
"avg_f1_score": avg_f1_score,
"map_score": ap_score,
}
def main(
model_name: str = "mobilevit_s",
lr: float = 1e-3,
epochs: int = 100,
batch_size: int = 32,
seed: int = 42,
    data_dir: Optional[str] = None,
test_run: bool = False
):
if test_run:
max_img_index = 10 * batch_size
epochs = 10
else:
max_img_index = -1
pl.seed_everything(seed, workers=True)
img_size = 120
channels = 10
model_config = ILMConfiguration(
timm_model_name=model_name,
hf_model_name=None,
classes=19,
image_size=img_size,
channels=channels,
network_type=ILMType.VISION_CLASSIFICATION
)
    # The API key is provided via the WANDB_API_KEY environment variable; the project name can be chosen freely.
wandb.login(key=wandb_api_key)
tags = ["Pretraining", model_name]
if test_run:
tags += ["Test Run"]
    wandb_logger = WandbLogger(
        project="LiT4RSVQA",
        log_model=True,
        tags=tags,  # keyword arguments are passed through to wandb.init()
    )
monitor = "val/mAP (Micro)"
monitor_str = "mAP_micro"
# checkpointing
checkpoint_callback = ModelCheckpoint(
monitor="val/f1",
dirpath="./checkpoints",
filename=f"{wandb_logger.experiment.name}-{model_name}-seed=" + str(seed) +
"-epoch={epoch:03d}-" + f"{monitor_str}" + "={" + f"{monitor}" +
":.3f}",
auto_insert_metric_name=False,
save_top_k=1,
mode="max",
save_last=True
)
early_stopping_callback = EarlyStopping(monitor=monitor, min_delta=0.00,
patience=25, verbose=False, mode="max")
lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = pl.Trainer(
max_epochs=epochs,
accelerator="auto",
log_every_n_steps=5,
logger=wandb_logger,
check_val_every_n_epoch=5,
callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
)
model = LitVisionEncoder(config=model_config, lr=lr)
dm = BENDataModule(
data_dir=resolve_ben_data_dir(data_dir=data_dir),
img_size=(channels, img_size, img_size),
num_workers_dataloader=4,
batch_size=batch_size,
max_img_idx=max_img_index,
)
trainer.fit(model=model, datamodule=dm)
trainer.test(model=model, datamodule=dm, ckpt_path="best")
wandb.finish()
print("=== Training finished ===")
if __name__ == "__main__":
typer.run(main)
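# Example invocation (illustrative; "pretrain.py" is a placeholder file name, and Typer derives
# the CLI option names from the parameters of main()):
#   python pretrain.py --model-name mobilevit_s --lr 0.001 --epochs 100 --batch-size 32 --test-run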