Skip to content
Snippets Groups Projects
Commit bb0e63df authored by Gencer Sumbul's avatar Gencer Sumbul
Browse files

fix

parent db2caeb7
No related branches found
No related tags found
No related merge requests found
general:
output_path: "/media/storagecube/gencer/semanticTripletSampling/dumps/das-its-ucm-efficientnet"
output_path: "/media/storagecube/gencer/semanticTripletSampling/dumps/das-its-ucm-vgg"
version: "1"
clear_old_output: false
......@@ -9,7 +9,7 @@ general:
archive_data_size: 420
model_args:
model_arch: "efficientnet"
model_arch: "vgg"
feature_size: 256
train:
......
......@@ -46,15 +46,10 @@ class PredefinedModel(BaseModel):
return self.final_layer(x)
class EfficientNetB0(PredefinedModel):
class VGG16(PredefinedModel):
def __init__(self, **kwargs):
super(EfficientNetB0, self).__init__(
tf.keras.applications.EfficientNet(
1.0,
1.0,
224,
0.2,
model_name='efficientnetb0'), **kwargs)
super(VGG16, self).__init__(
tf.keras.applications.VGG16, **kwargs)
class ResNet50v2(PredefinedModel):
def __init__(self, **kwargs):
......@@ -126,8 +121,8 @@ class TripletModel(BaseModel):
self.internal_model = DenseNet()
elif model_arch == 'scnn':
self.internal_model = SCNN()
elif model_arch == 'efficientnet':
self.internal_model = EfficientNetB0()
elif model_arch == 'vgg':
self.internal_model = VGG16()
else:
raise ValueError(f'Invalid model architecture "{model_arch}".')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment