From 4522d5d65af3d548395b91b3ac98ab3ca5f6e730 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Baris=20B=C3=BCy=C3=BCktas?= <baris.bueyuektas@tu-berlin.de> Date: Fri, 13 Jan 2023 14:46:46 +0000 Subject: [PATCH] Update evaluation.py --- evaluation.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 evaluation.py diff --git a/evaluation.py b/evaluation.py new file mode 100644 index 0000000..9979f49 --- /dev/null +++ b/evaluation.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[ ]: + +import torch.nn as nn +import torch +from pytorch_datasets import Ben19Dataset +from pytorch_datasets1 import Ben19Dataset1 +from pytorch_utils import start_cuda, get_classification_report, print_micro_macro +from sklearn.metrics import classification_report +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm +import numpy as np +from model import ANN +from model import ANN_S2 +# In[ ]: + +def change_sizes(labels): + new_labels=np.zeros((len(labels[0]),19)) + for i in range(len(labels[0])): #128 + + for j in range(len(labels)): #19 + new_labels[i,j] = int(labels[j][i]) + + return torch.from_numpy(new_labels) + + +def test(ann,ann1,ann2): + ann.eval() + lmdb_path = '/faststorage/BigEarthNet_S1_S2/BEN_S1_S2.lmdb' + batch_size=256 + num_workers=4 + y_true = [] + predicted_probs = [] + + + + csv_val_path='/data/all_test.csv' + val_set = Ben19Dataset(lmdb_path=lmdb_path, csv_path=csv_val_path, img_transform='default') + val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers,shuffle=False, pin_memory=True) + + csv_val_path='/data/all_test.csv' + val_set1 = Ben19Dataset1(lmdb_path=lmdb_path, csv_path=csv_val_path, img_transform='default') + val_loader1 = DataLoader(val_set1, batch_size=batch_size, num_workers=num_workers,shuffle=False, pin_memory=True) + + y_true = [] + s1 = torch.ones(1,2048).cuda() + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(val_loader, desc="test")): + + data = batch['s1_bands'].cuda() + labels = batch['label'] + + labels=labels.detach().numpy() + y_true += list(labels) + logits,x = model1(data) + + s1 = torch.cat((s1,x),0) + + s2 = torch.ones(1,2048).cuda() + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(val_loader1, desc="test")): + + data = batch['data'].cuda() + + labels = batch['label'] + + + labels=change_sizes(labels) + labels=labels.detach().numpy() + + logits,x1 = model2(data) + s2 = torch.cat((s2,x1),0) + + s1 = torch.cat((s1,s2),1) + s1 = s1[1:,:] + print(s1.shape) + new_classifier = nn.Sequential(*list(ann1.children())) + print(new_classifier) + predicted_probs=[] + for i in range(s1.shape[0]): + + probs = torch.sigmoid( torch.matmul( s1[i,:] , torch.transpose( ann.state_dict()['FC.weight'],0,1) ) + ann.state_dict()['FC.bias'] ).cpu().numpy() + probs = probs.reshape(1,-1) + + predicted_probs += list(probs) + + + predicted_probs = np.asarray(predicted_probs) + y_predicted = (predicted_probs >= 0.5).astype(np.float32) + + + y_true = np.asarray(y_true) + + print(len(y_true),len(y_predicted)) + print( classification_report(y_true, y_predicted) ) + report = get_classification_report(y_true, y_predicted, predicted_probs, 'aaaa') + print_micro_macro(report) + + +# In[ ]: + + +net = torch.load('/data/scenario2-1.pth') +model = ANN(name='server').to("cuda") +model.load_state_dict(net) + + +net1 = torch.load('/data/scenario2-2.pth') +model1 = ANN(name='server1').to("cuda") +model1.load_state_dict(net1) + +net2 = torch.load('/data/scenario2-3.pth') +model2 = ANN_S2(name='server2').to("cuda") +model2.load_state_dict(net2) + + +test(model,model1,model2) + -- GitLab