Skip to content
Snippets Groups Projects
Commit 4522d5d6 authored by Baris Büyüktas's avatar Baris Büyüktas
Browse files

Update evaluation.py

parent 71f49049
No related branches found
No related tags found
1 merge request!1Update README.md, data/austria_train.csv, data/switzerland_train.csv,...
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
import torch.nn as nn
import torch
from pytorch_datasets import Ben19Dataset
from pytorch_datasets1 import Ben19Dataset1
from pytorch_utils import start_cuda, get_classification_report, print_micro_macro
from sklearn.metrics import classification_report
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from model import ANN
from model import ANN_S2
# In[ ]:
def change_sizes(labels):
new_labels=np.zeros((len(labels[0]),19))
for i in range(len(labels[0])): #128
for j in range(len(labels)): #19
new_labels[i,j] = int(labels[j][i])
return torch.from_numpy(new_labels)
def test(ann,ann1,ann2):
ann.eval()
lmdb_path = '/faststorage/BigEarthNet_S1_S2/BEN_S1_S2.lmdb'
batch_size=256
num_workers=4
y_true = []
predicted_probs = []
csv_val_path='/data/all_test.csv'
val_set = Ben19Dataset(lmdb_path=lmdb_path, csv_path=csv_val_path, img_transform='default')
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers,shuffle=False, pin_memory=True)
csv_val_path='/data/all_test.csv'
val_set1 = Ben19Dataset1(lmdb_path=lmdb_path, csv_path=csv_val_path, img_transform='default')
val_loader1 = DataLoader(val_set1, batch_size=batch_size, num_workers=num_workers,shuffle=False, pin_memory=True)
y_true = []
s1 = torch.ones(1,2048).cuda()
with torch.no_grad():
for batch_idx, batch in enumerate(tqdm(val_loader, desc="test")):
data = batch['s1_bands'].cuda()
labels = batch['label']
labels=labels.detach().numpy()
y_true += list(labels)
logits,x = model1(data)
s1 = torch.cat((s1,x),0)
s2 = torch.ones(1,2048).cuda()
with torch.no_grad():
for batch_idx, batch in enumerate(tqdm(val_loader1, desc="test")):
data = batch['data'].cuda()
labels = batch['label']
labels=change_sizes(labels)
labels=labels.detach().numpy()
logits,x1 = model2(data)
s2 = torch.cat((s2,x1),0)
s1 = torch.cat((s1,s2),1)
s1 = s1[1:,:]
print(s1.shape)
new_classifier = nn.Sequential(*list(ann1.children()))
print(new_classifier)
predicted_probs=[]
for i in range(s1.shape[0]):
probs = torch.sigmoid( torch.matmul( s1[i,:] , torch.transpose( ann.state_dict()['FC.weight'],0,1) ) + ann.state_dict()['FC.bias'] ).cpu().numpy()
probs = probs.reshape(1,-1)
predicted_probs += list(probs)
predicted_probs = np.asarray(predicted_probs)
y_predicted = (predicted_probs >= 0.5).astype(np.float32)
y_true = np.asarray(y_true)
print(len(y_true),len(y_predicted))
print( classification_report(y_true, y_predicted) )
report = get_classification_report(y_true, y_predicted, predicted_probs, 'aaaa')
print_micro_macro(report)
# In[ ]:
net = torch.load('/data/scenario2-1.pth')
model = ANN(name='server').to("cuda")
model.load_state_dict(net)
net1 = torch.load('/data/scenario2-2.pth')
model1 = ANN(name='server1').to("cuda")
model1.load_state_dict(net1)
net2 = torch.load('/data/scenario2-3.pth')
model2 = ANN_S2(name='server2').to("cuda")
model2.load_state_dict(net2)
test(model,model1,model2)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment