# client.py
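# Federated-learning client: builds the local model from a modality-specific
# backbone (the S1/S2 model pair, apparently Sentinel-1 / Sentinel-2 given the
# `s1_bands` batch key) plus the server model's classification head, then
# trains it with a multi-label BCE loss over 19 classes and an NT-Xent
# contrastive loss that aligns features across the two modalities.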

import copy

import numpy as np
import torch
from pytorch_metric_learning.losses import NTXentLoss
from torch import nn
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

from get_data import nn_seq_wind
from model import ANN, ANN_S2

# Per-sample BCE over the multi-label outputs, plus an NT-Xent contrastive
# loss used to align local features with the other modality's features.
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
loss_con = NTXentLoss(temperature=0.1)


def change_sizes(labels):
    """Transpose a list of 19 per-class label sequences into a (batch, 19) tensor."""
    new_labels = np.zeros((len(labels[0]), 19))
    for i in range(len(labels[0])):   # sample index within the batch
        for j in range(len(labels)):  # class index
            new_labels[i, j] = int(labels[j][i])
    return torch.from_numpy(new_labels)

def get_val_loss(args, model, Val):
    """Mean MSE loss of `model` over the validation loader `Val`."""
    model.eval()
    loss_function = nn.MSELoss().to(args.device)
    val_loss = []
    with torch.no_grad():
        for seq, label in Val:
            seq = seq.to(args.device)
            label = label.to(args.device)
            y_pred = model(seq)
            val_loss.append(loss_function(y_pred, label).item())
    return np.mean(val_loss)

def manipulation(model, layer, parameters):
    """Copy `parameters` into the trainable parameter named `layer` of `model`.

    Only entries returned by `named_parameters()` can be overwritten, so
    buffers (e.g. BatchNorm running statistics) are left untouched.
    """
    for name, param in model.named_parameters():
        if param.requires_grad and name == layer:
            param.data = parameters

def train(args, model, server, k, s1, s2):
    # Assemble the local model: clients with k <= 2 train the S2 variant of the
    # network, the remaining clients the S1 variant. Backbone weights
    # (state-dict entries 0..318) come from the matching modality model, head
    # weights (entries > 318) from the server model.
    if k <= 2:
        model_random = ANN_S2(args=args, name='server6').to(args.device)
        source = s2
    else:
        model_random = ANN(args=args, name='server7').to(args.device)
        source = s1

    src_sd = source.state_dict()
    srv_sd = server.state_dict()
    manipulation(model_random, 'conv1.weight', src_sd['conv1.weight'])
    manipulation(model_random, 'encoder.0.weight', src_sd['encoder.0.weight'])
    for j, key in enumerate(model_random.state_dict()):
        if 1 < j <= 318:
            manipulation(model_random, key, src_sd[key])
        elif j > 318:
            manipulation(model_random, key, srv_sd[key])

    model = copy.deepcopy(model_random)
    model.train()
    Dtr, Dte = nn_seq_wind(k)
    model.len = len(Dtr)
    lr = args.lr
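    # Adam or momentum-SGD, both with the configured weight decay; the StepLR
    # scheduler below decays the learning rate by `gamma` every `step_size` epochs.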
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9, weight_decay=args.weight_decay)
    stepLR = StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    # training
    best_model = None
    print('training...')
  
    for epoch in tqdm(range(args.E)):
        running_loss = 0.0
        for idx, batch in enumerate(tqdm(Dtr, desc="training")):
            if k <= 2:
                # S2 client: pass the batch through the first block of the local
                # model, then through the remaining blocks of the S1 model to
                # obtain the cross-modal features for the contrastive loss.
                seq, labels, index = batch['data'], batch['label'], batch['index']
                seq = seq.cuda()
                label = change_sizes(np.copy(labels))

                s2_new = nn.Sequential(list(model.children())[1][0:1])
                cross_features = s2_new(seq)
                s1_new = nn.Sequential(*list(s1.children())[1][1:])
                # Detach so no gradients flow into the frozen cross-modal branch
                # (the original `features_s2.detach()` discarded its result).
                cross_features = s1_new(cross_features).detach()
            else:
                # S1 client: symmetric to the S2 case, with the roles of the two
                # modality models swapped.
                seq, labels, index = batch['s1_bands'], batch['label'], batch['index']
                seq = seq.cuda()
                label = torch.from_numpy(np.copy(labels))

                s1_new = nn.Sequential(list(model.children())[1][0:1])
                cross_features = s1_new(seq)
                s2_new = nn.Sequential(*list(s2.children())[1][1:])
                cross_features = s2_new(cross_features).detach()

            label = label.cuda()
            model_new = nn.Sequential(list(model.children())[1])
            features = model_new(seq)
            y_pred = model(seq)

            # NT-Xent loss: a local feature and its cross-modal counterpart share
            # the same index, so they form a positive pair.
            mm_labels = torch.arange(seq.shape[0])
            features_emd = torch.cat((features[:, :, 0, 0], cross_features[:, :, 0, 0]), 0)
            labels_emd = torch.cat((mm_labels, mm_labels), 0)
            loss_con1 = loss_con(features_emd, labels_emd)

            optimizer.zero_grad()
            loss = criterion(y_pred, label) + 0.01 * loss_con1
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        stepLR.step()

        # Keep a copy of the model after each epoch; the copy from the final
        # epoch is what gets returned.
        best_model = copy.deepcopy(model)
        print(f'epoch {epoch}: running loss {running_loss:.4f}')
    return best_model
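

# Minimal usage sketch (assumptions: `args` only needs the fields read above;
# `ANN` / `ANN_S2` accept these constructor arguments and share the state-dict
# layout expected by `train`; client index 0 is an S2 client). Not part of the
# original training pipeline.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(device='cuda', lr=1e-3, optimizer='adam',
                           weight_decay=1e-5, step_size=10, gamma=0.5, E=5)
    s1 = ANN(args=args, name='s1').to(args.device)
    s2 = ANN_S2(args=args, name='s2').to(args.device)
    server = ANN(args=args, name='server').to(args.device)
    # `model` is rebuilt inside train(), so the argument can be a placeholder.
    client_model = train(args, None, server, k=0, s1=s1, s2=s2)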