I am working on a federated learning project and wrote some code to simulate the federated learning process. However, after the global aggregation in every round, the test accuracy of the global model drops sharply and then stays unchanged in the following rounds. The aggregation algorithm I use is FedAvg. I have split my code into separate units to isolate the problem. For local training, each selected client trains for 3 epochs; in this experiment all five clients are selected for training and aggregation. The local model is the VGG16 forked from torchvision, and the dataset is MNIST, split in an i.i.d. manner across the clients.
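The update I intend to implement is the standard FedAvg weighted average, w_global = Σ_k (n_k / n) · w_k, where n_k is the number of samples on client k and n = Σ_k n_k. The local training loop is: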
```python
for id, net_id in enumerate(selected):
    logging.info("Training Selected Device %s." % (str(net_id)))
    result = Userlists[net_id].train(hparams['n_local_epochs'])
    logging.info('>> Local model %d: local accuracy: %f in round %d\n'
                 % (id, result['local_test_acc'], step + 1))
```
Before aggregating the local models, I test the accuracy of each local model on the global server's test data:
```python
tesc, conf = misc.compute_accuracy(Userlists[2].model, test_dl_global,
                                   get_confusion_matrix=True, device=hparams['device'])
print(tesc)
# 0.2478966346153846

tesc, conf = misc.compute_accuracy(Userlists[3].model, test_dl_global,
                                   get_confusion_matrix=True, device=hparams['device'])
print(tesc)
# 0.14413060897435898

tesc, conf = misc.compute_accuracy(Userlists[4].model, test_dl_global,
                                   get_confusion_matrix=True, device=hparams['device'])
print(tesc)
# 0.17387820512820512
```
I then aggregate the weights of the selected clients with the code below:
```python
total_sum = 0.0
for client_idx in selected:
    total_sum += Userlists[client_idx].data_len

global_para = global_model.state_dict()
client_weights = [torch.tensor(Userlists[client_idx].data_len / total_sum,
                               device=hparams['device'])
                  for client_idx in selected]

with torch.no_grad():
    for order, idx in enumerate(selected):
        logging.info(f"For Client {idx}")
        net_para = Userlists[idx].model.state_dict()
        if order == 0:
            for key in net_para.keys():
                global_para[key] = net_para[key] * client_weights[order]
        else:
            for key in net_para.keys():
                global_para[key] += net_para[key] * client_weights[order]

global_model.load_state_dict(global_para)
tesc, conf = misc.compute_accuracy(global_model, train_dl_global,
                                   get_confusion_matrix=True, device=hparams['device'])
```
After the aggregation, the global test accuracy drops to the value below (roughly chance level for the 10 MNIST classes) and stays there:
> 0.11236666666666667
Although I have tried increasing the number of local epochs until the local accuracy rises to about 40%, the global accuracy still falls to the same value as before. Is there anything wrong in my aggregation code?
I would expect the global test accuracy to stay at roughly the same level as the local accuracies.
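For clarity, the weighted average I intend the aggregation loop above to compute is equivalent to this minimal sketch (the helper name fedavg_state_dict is just for illustration, and it assumes every state_dict entry is a floating-point tensor):

```python
import torch

def fedavg_state_dict(models, data_lens):
    """Return the data-size-weighted average of the models' state_dicts.

    Assumes every state_dict entry is a floating-point tensor; integer
    buffers (e.g. BatchNorm's num_batches_tracked) would need special
    handling, but the plain torchvision vgg16 has none.
    """
    total = float(sum(data_lens))
    weights = [n / total for n in data_lens]  # these sum to 1.0

    avg = {}
    with torch.no_grad():
        for w, model in zip(weights, models):
            for key, param in model.state_dict().items():
                if key not in avg:
                    avg[key] = param * w   # first client initializes the running sum
                else:
                    avg[key] += param * w  # remaining clients accumulate into it
    return avg

# Usage with the objects from the post:
# global_model.load_state_dict(
#     fedavg_state_dict([Userlists[i].model for i in selected],
#                       [Userlists[i].data_len for i in selected]))
```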