From 4522d5d65af3d548395b91b3ac98ab3ca5f6e730 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Baris=20B=C3=BCy=C3=BCktas?= <baris.bueyuektas@tu-berlin.de>
Date: Fri, 13 Jan 2023 14:46:46 +0000
Subject: [PATCH] Update evaluation.py

---
 evaluation.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 120 insertions(+)
 create mode 100644 evaluation.py

diff --git a/evaluation.py b/evaluation.py
new file mode 100644
index 0000000..9979f49
--- /dev/null
+++ b/evaluation.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# In[ ]:
+
+import torch.nn as nn
+import torch
+from pytorch_datasets import Ben19Dataset
+from pytorch_datasets1 import Ben19Dataset1
+from pytorch_utils import start_cuda, get_classification_report, print_micro_macro
+from sklearn.metrics import classification_report
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+import numpy as np
+from model import ANN
+from model import ANN_S2
+# In[ ]:
+
def change_sizes(labels):
    """Transpose a nested 19 x batch label structure into a (batch, 19) tensor.

    `labels` is indexed as labels[class_idx][sample_idx] (19 rows, one per
    class); the result is a float64 tensor indexed as [sample_idx, class_idx],
    with each entry truncated to an integer value via int().
    """
    batch = len(labels[0])
    out = np.zeros((batch, 19))
    for cls_idx, row in enumerate(labels):  # 19 class rows
        for sample_idx in range(batch):     # batch entries per row
            out[sample_idx, cls_idx] = int(row[sample_idx])
    return torch.from_numpy(out)
+
+
def test(ann, ann1, ann2):
    """Evaluate the two-branch (S1 + S2) late-fusion model on the test set.

    Parameters
    ----------
    ann  : fusion classifier; its ``FC`` layer maps the concatenated
           2048 + 2048 branch features to the 19 class logits.
    ann1 : S1 branch network; forward returns (logits, 2048-dim features).
    ann2 : S2 branch network; forward returns (logits, 2048-dim features).

    Prints sklearn's classification report plus the project's micro/macro
    summary. No return value.
    """
    # BUG FIX: the original body ignored the ann1/ann2 parameters and used
    # the module-level globals model1/model2 instead, and only put `ann`
    # in eval mode even though inference runs through the branch networks.
    ann.eval()
    ann1.eval()
    ann2.eval()

    lmdb_path = '/faststorage/BigEarthNet_S1_S2/BEN_S1_S2.lmdb'
    csv_val_path = '/data/all_test.csv'
    batch_size = 256
    num_workers = 4

    val_set = Ben19Dataset(lmdb_path=lmdb_path, csv_path=csv_val_path,
                           img_transform='default')
    val_loader = DataLoader(val_set, batch_size=batch_size,
                            num_workers=num_workers, shuffle=False,
                            pin_memory=True)

    val_set1 = Ben19Dataset1(lmdb_path=lmdb_path, csv_path=csv_val_path,
                             img_transform='default')
    val_loader1 = DataLoader(val_set1, batch_size=batch_size,
                             num_workers=num_workers, shuffle=False,
                             pin_memory=True)

    # Pass 1: collect ground truth and S1 branch features.
    y_true = []
    s1_feats = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="test"):
            data = batch['s1_bands'].cuda()
            y_true += list(batch['label'].detach().numpy())
            _, feats = ann1(data)
            s1_feats.append(feats)

    # Pass 2: S2 branch features. Both loaders iterate the same CSV with
    # shuffle=False, so rows stay aligned sample-for-sample.
    s2_feats = []
    with torch.no_grad():
        for batch in tqdm(val_loader1, desc="test"):
            data = batch['data'].cuda()
            _, feats = ann2(data)
            s2_feats.append(feats)

    # Concatenate along the feature axis: (N, 2048) + (N, 2048) -> (N, 4096).
    # (The original seeded each buffer with a dummy torch.ones row and sliced
    # it off afterwards; accumulating in lists avoids that hack.)
    fused = torch.cat((torch.cat(s1_feats, 0), torch.cat(s2_feats, 0)), 1)
    print(fused.shape)

    # Apply the fusion FC layer in one vectorized step instead of a per-row
    # matmul with state_dict lookups inside the loop.
    state = ann.state_dict()
    weight = state['FC.weight']   # (19, 4096)
    bias = state['FC.bias']       # (19,)
    predicted_probs = torch.sigmoid(fused @ weight.t() + bias).cpu().numpy()

    y_predicted = (predicted_probs >= 0.5).astype(np.float32)
    y_true = np.asarray(y_true)

    print(len(y_true), len(y_predicted))
    print(classification_report(y_true, y_predicted))
    report = get_classification_report(y_true, y_predicted, predicted_probs, 'aaaa')
    print_micro_macro(report)
+
+
+# In[ ]:
+
+
def _restore(weights_path, net):
    """Load a checkpoint from `weights_path` into `net` and move it to the GPU.

    map_location keeps torch.load from failing when the checkpoint was saved
    on a different device than the one it is restored on.
    """
    state = torch.load(weights_path, map_location='cuda')
    net = net.to('cuda')
    net.load_state_dict(state)
    return net


if __name__ == '__main__':
    # Guarded so importing this module for its helpers does not trigger
    # checkpoint loading and a full evaluation run.
    model = _restore('/data/scenario2-1.pth', ANN(name='server'))    # fusion head
    model1 = _restore('/data/scenario2-2.pth', ANN(name='server1'))  # S1 branch
    model2 = _restore('/data/scenario2-3.pth', ANN_S2(name='server2'))  # S2 branch

    test(model, model1, model2)
+
-- 
GitLab