Skip to content
Snippets Groups Projects
Commit ec7f9ceb authored by Leonard Wayne Hackel's avatar Leonard Wayne Hackel
Browse files

fixing seed in name, fixing none as warmup

parent ec5e762c
No related branches found
No related tags found
No related merge requests found
......@@ -35,7 +35,7 @@ class BENv2ImageEncoder(pl.LightningModule, PyTorchModelHubMixin):
):
super().__init__()
self.lr = lr
self.warmup = None if warmup < 0 else warmup
self.warmup = None if warmup is None or warmup < 0 else warmup
self.config = config
assert config.network_type == ILMType.IMAGE_CLASSIFICATION
assert config.classes == 19
......
......@@ -16,7 +16,7 @@ def get_arch_version_bandconfig(model_name: str, config: ILMConfiguration):
architecture = model_name.split("/")[-1].split("-")[1]
assert architecture == config.timm_model_name, f"Model name {architecture} does not match config {config.timm_model_name}"
version = model_name.split("/")[-1].split("-")[-1]
bandconfig = model_name.split("/")[-1].split("-")[3]
bandconfig = model_name.split("/")[-1].split("-")[2]
if bandconfig == "s2":
assert config.channels == 10, f"Bandconfig {bandconfig} does not match config {config.channels}"
elif bandconfig == "s1":
......@@ -126,7 +126,7 @@ def train_new_model(
def main(
model_name: str = typer.Option("BIFOLD-BigEarthNetv2-0/BENv2-resnet50-42-s2-v0.1.1", help="Model name"),
model_name: str = typer.Option("BIFOLD-BigEarthNetv2-0/BENv2-resnet50-all-v0.1.1", help="Model name"),
seed: int = typer.Option(42, help="Random seed"),
lr: float = typer.Option(0.001, help="Learning rate"),
epochs: int = typer.Option(100, help="Number of epochs"),
......@@ -179,7 +179,7 @@ def main(
print(f"New model improved the compare metric by {new_metric - compare_metric:.4f}")
print("=== Uploading model to Huggingface Hub ===")
architecture, version, bandconfig = get_arch_version_bandconfig(model_name, config)
new_model_name = f"BENv2-{architecture}-{seed}-{bandconfig}-{version}"
new_model_name = f"BENv2-{architecture}-{bandconfig}-{version}"
assert model_name == new_model_name, f"Model name {model_name} does not match new model name {new_model_name}"
if upload_hf_entity:
print(f"Uploading model as {model_name}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment