I'm training a neural network (it doesn't matter which one) on the CIFAR-10 dataset using Federated Learning:
- I have 10 models, each with access to its own part of the dataset. At every time step, each model takes a step on its own data, and then the global model is set to the average of the client models (this version is based on this, but I tried a lot of options):
def server_aggregate(server_model, client_models):
    global_dict = server_model.state_dict()
    # Average every state_dict entry across all client models.
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0).mean(0)
    server_model.load_state_dict(global_dict)
    # Broadcast the averaged weights back to every client.
    for model in client_models:
        model.load_state_dict(server_model.state_dict())
- To be specific, each machine only has access to the data for a single class, i.e. machine 0 has only samples of class 0, etc. I'm doing it the following way:
import torch
from torch.utils.data import DataLoader

def split_into_classes(full_ds, batch_size, num_classes=10):
    # Group sample indices by label, so that client i gets only class i.
    class2indices = [[] for _ in range(num_classes)]
    for i, y in enumerate(full_ds.targets):
        class2indices[y].append(i)
    datasets = [torch.utils.data.Subset(full_ds, indices) for indices in class2indices]
    return [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets]
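One way to sanity-check such a split is to count the labels each client would see. The sketch below mirrors the index grouping in split_into_classes on a toy label list. The names labels_per_client and toy_targets are hypothetical, and real code would pass full_ds.targets instead:

```python
from collections import Counter

def labels_per_client(targets, num_classes=10):
    """Group sample indices by label (as split_into_classes does) and
    return, for each client, a Counter of the labels it would see."""
    class2indices = [[] for _ in range(num_classes)]
    for i, y in enumerate(targets):
        class2indices[int(y)].append(i)
    return [Counter(int(targets[i]) for i in indices) for indices in class2indices]

# Toy stand-in for full_ds.targets (made-up labels, not CIFAR-10).
toy_targets = [0, 1, 2, 0, 1, 2, 2]
for client_id, counts in enumerate(labels_per_client(toy_targets, num_classes=3)):
    print(client_id, dict(counts))
```

If the split behaves as described, every client's Counter has exactly one key, equal to the client index.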
Problem. During training, I can see that my federated training loss decreases. However, I never see my test loss/accuracy improve (acc is always around 10%). Moreover, when I check accuracy on train/test datasets:
- For the federated dataset, the accuracy improves.
- For the testing dataset, the accuracy doesn't improve.
- (Most surprising) For the training dataset, the accuracy doesn't improve. Note that this dataset is essentially the same as the federated dataset, just not split into classes.
The checking code is the following:
def epoch_summary(model, fed_loaders, true_train_loader, test_loader, frac):
    with torch.no_grad():
        train_len = 0
        train_loss, train_acc = 0, 0
        for train_loader in fed_loaders:
            cur_loss, cur_acc, cur_len = true_results(model, train_loader, frac)
            train_loss += cur_len * cur_loss
            train_acc += cur_len * cur_acc
            train_len += cur_len
        train_loss /= train_len
        train_acc /= train_len
        true_train_loss, true_train_acc, true_train_len = true_results(model, true_train_loader, frac)
        test_loss, test_acc, test_len = true_results(model, test_loader, frac)
        print("TrainLoss: {:.4f} TrainAcc: {:.2f} TrueLoss: {:.4f} TrueAcc: {:.2f} TestLoss: {:.4f} TestAcc: {:.2f}".format(
            train_loss, train_acc, true_train_loss, true_train_acc, test_loss, test_acc
        ), flush=True)
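The federated numbers in epoch_summary are a size-weighted mean of per-loader results. Here is a minimal sketch of just that arithmetic, assuming each per-loader result is a (mean loss, mean accuracy, sample count) tuple, as the call to true_results suggests. The name weighted_summary and the toy numbers are hypothetical:

```python
def weighted_summary(results):
    """Combine per-loader (mean_loss, mean_acc, num_samples) tuples into
    overall means, weighting each loader by its sample count."""
    total = sum(n for _, _, n in results)
    loss = sum(l * n for l, _, n in results) / total
    acc = sum(a * n for _, a, n in results) / total
    return loss, acc, total

# Toy per-client results (made-up numbers, for illustration only).
per_client = [(1.0, 0.5, 10), (3.0, 1.0, 30)]
print(weighted_summary(per_client))
```

If all partitions have the same size, this reduces to a plain average over clients.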
The full code can be found here.
Things that don't seem to matter:
- Model. I got the same problem for ResNet models and for some other models.
- How I aggregate the models. I tried using state_dict or directly manipulating model.parameters(), no effect.
- How I learn the models. I tried using optim.SGD or directly updating param.data -= learning_rate * param.grad, no effect.
- Computational graph. I've tried adding .detach().clone() and with torch.no_grad() in all possible places, no effect.
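For concreteness, a parameter-based aggregation of the kind mentioned above could look roughly like the sketch below. This is an illustration under assumptions, not the code from the full script: average_parameters and the toy nn.Linear models are hypothetical names.

```python
import torch
import torch.nn as nn

def average_parameters(server_model, client_models):
    """Set each server parameter to the mean of the matching client
    parameters, then copy the result back into every client."""
    with torch.no_grad():
        client_params = [list(m.parameters()) for m in client_models]
        server_params = list(server_model.parameters())
        for idx, server_p in enumerate(server_params):
            stacked = torch.stack([params[idx] for params in client_params], 0)
            server_p.copy_(stacked.mean(0))
        for params in client_params:
            for p, server_p in zip(params, server_params):
                p.copy_(server_p)

# Toy usage with tiny linear layers (hypothetical models, for illustration only).
server = nn.Linear(2, 1)
clients = [nn.Linear(2, 1) for _ in range(3)]
expected = torch.stack([c.weight.detach().clone() for c in clients], 0).mean(0)
average_parameters(server, clients)
print(torch.allclose(server.weight, expected))
```

Note that model.parameters() covers only learnable parameters, whereas state_dict() also includes buffers such as BatchNorm running statistics, so the two aggregation variants are not strictly equivalent for models with such layers.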
So I suspect that the problem is somehow with the federated data itself (especially given the strange accuracy results). What could the problem be?