
I'm training a neural network (doesn't matter which one) on the CIFAR-10 dataset. I'm using Federated Learning:

  • I have 10 models, each having access to its own part of the dataset. At every time step, each model takes a step using its own data, and then the global model is the average of the client models (this version is based on this, but I tried a lot of options; a sketch of one full round appears after this list):
def server_aggregate(server_model, client_models):
    # Average every parameter/buffer across clients, load the result into the
    # server model, then broadcast it back so all clients start the next step
    # from the same weights.
    global_dict = server_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack(
            [client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0
        ).mean(0)
    server_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(server_model.state_dict())
  • To be specific, each machine only has access to the data corresponding to a single class. I.e., machine 0 only has samples of class 0, etc. I'm doing it the following way:
def split_into_classes(full_ds, batch_size, num_classes=10):
  # Group sample indices by class label, then build one DataLoader per class.
  class2indices = [[] for _ in range(num_classes)]
  for i, y in enumerate(full_ds.targets):
    class2indices[y].append(i)

  datasets = [torch.utils.data.Subset(full_ds, indices) for indices in class2indices]
  return [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets]
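Putting these pieces together, one communication round looks roughly like the following. This is a simplified sketch, not the exact loop from the linked code: the function name, the per-client SGD optimizer, the learning rate, and the cross-entropy criterion are illustrative choices.

import torch
import torch.nn as nn

def train_one_round(server_model, client_models, client_loaders, lr=0.01, device="cpu"):
    # Each client takes one SGD step on a batch from its own (single-class) loader,
    # then the server averages the updated clients and broadcasts the result.
    criterion = nn.CrossEntropyLoss()
    for model, loader in zip(client_models, client_loaders):
        model.train()
        x, y = next(iter(loader))            # one batch from this client's class
        x, y = x.to(device), y.to(device)
        opt = torch.optim.SGD(model.parameters(), lr=lr)
        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()
    server_aggregate(server_model, client_models)  # defined above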

Problem. During training, I can see that my federated training loss decreases. However, I never see my test loss/accuracy improve (acc is always around 10%). Moreover, when I check accuracy on train/test datasets:

  • For the federated dataset, the accuracy improves.
  • For the testing dataset, the accuracy doesn't improve.
  • (Most surprising) for the training dataset, the accuracy doesn't improve. Note that this dataset is essentially the same as the federated dataset, but not split into classes. The checking code is the following (a sketch of the true_results helper it relies on appears after the block):
def epoch_summary(model, fed_loaders, true_train_loader, test_loader, frac):
  with torch.no_grad():
    train_len = 0
    train_loss, train_acc = 0, 0
    for train_loader in fed_loaders:
      cur_loss, cur_acc, cur_len = true_results(model, train_loader, frac)
      train_loss += cur_len * cur_loss
      train_acc += cur_len * cur_acc
      train_len += cur_len

    train_loss /= train_len
    train_acc /= train_len

    true_train_loss, true_train_acc, true_train_len = true_results(model, true_train_loader, frac)
    test_loss, test_acc, test_len = true_results(model, test_loader, frac)

  print("TrainLoss: {:.4f} TrainAcc: {:.2f} TrueLoss: {:.4f} TrueAcc: {:.2f} TestLoss: {:.4f} TestAcc: {:.2f}".format(
        train_loss, train_acc, true_train_loss, true_train_acc, test_loss, test_acc
        ), flush=True)
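true_results itself is in the linked full code; from how epoch_summary weights its outputs by cur_len, it returns the per-sample loss, the accuracy, and the number of samples for a loader. A simplified sketch of what it might look like (assuming the frac argument just caps the fraction of batches evaluated, which is a guess):

import torch.nn as nn

def true_results(model, loader, frac=1.0):
    # Hypothetical reconstruction: average cross-entropy loss, accuracy (in
    # percent), and number of samples, over roughly the first `frac` of the
    # batches in `loader`.
    criterion = nn.CrossEntropyLoss()
    total_loss, correct, seen = 0.0, 0, 0
    max_batches = max(1, int(frac * len(loader)))
    for i, (x, y) in enumerate(loader):
        if i >= max_batches:
            break
        logits = model(x)
        total_loss += criterion(logits, y).item() * len(y)
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += len(y)
    return total_loss / seen, 100.0 * correct / seen, seen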

The full code can be found here. Things which don't seem to matter:

  • Model. I got the same problem for ResNet models and for some other models.
  • How I aggregate the models. I tried using state_dict and directly manipulating model.parameters(); no effect.
  • How I train the models. I tried using optim.SGD and directly updating param.data -= learning_rate * param.grad; no effect.
  • Computational graph. I've tried adding .detach().clone() and with torch.no_grad() in all possible places; no effect.

So I'm suspecting that the problem is somehow with the federated data itself (especially given the strange accuracy results). What could the problem be?

Dmitry
  • *"During training, I can see that my training loss decreases. However, I never see my test loss/accuracy improve"* is that the federated model on the whole train/test datasets? – Ivan Jan 31 '21 at 09:49
  • I'm not sure I understand the question, but yes. Note that I evaluate the model on the full training data (so it should be perfect, but it's not). – Dmitry Jan 31 '21 at 15:31
  • @Ivan, The experiments show that the problem was batch normalization (when I use dense model or replace BN layers with Identity layer, it works). Looks like you shouldn't average it like this. Do you know what's the proper way to handle them during model averaging? – Dmitry Jan 31 '21 at 18:06
  • *"machine 0 has only samples corresponding to class 0"* I don't understand how you would go about training a model for *1-class* classification task if machine `0` only has data points of class `0`. This doesn't make sense to me. – Ivan Jan 31 '21 at 21:11
  • @Ivan, on the high level: at one step, machine 0 learns about class 0, machine 1 about class 1, etc. When you average these models, you somewhat learn information about all classes. Check some "federated learning" papers - it's essentially distributed learning with non-IID data and infrequent communication, and the process is shown to converge. – Dmitry Jan 31 '21 at 21:37
  • @Ivan, I forgot to say that after averaging, the resulting model replaces the local models – Dmitry Jan 31 '21 at 23:14
  • I see, that makes more sense. Do you have a particular paper in mind? – Ivan Jan 31 '21 at 23:21
  • 1
    I think [Local SGD Converges Fast and Communicates Little](https://arxiv.org/pdf/1805.09767.pdf) covers both points I've mentioned. – Dmitry Feb 01 '21 at 03:48

1 Answer


10% accuracy on CIFAR-10 is basically random: a model that outputs labels at random gets about 10%.

I think the problem lies in your "federated training" strategy: you cannot expect your sub-models to learn anything meaningful when all they see is a single label. This is why training data is normally shuffled.
Think of it this way: suppose each of your sub-models learns all weights to be zero, apart from the bias vector of the last classification layer, which has a 1 in the entry corresponding to the single class that sub-model sees. Then each sub-model's training is perfect (it gets every sample it sees right), but the averaged model is meaningless.
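To make that concrete, here is a tiny numeric sketch of that degenerate solution (not the asker's model, just an illustration): give client i a final layer with zero weights and a bias that is one-hot at class i. Each client is perfect on its own single-class data, yet the averaged bias is uniform, so the averaged model ties all classes.

import torch

num_classes = 10
# Client i's degenerate "solution": all weights zero, bias one-hot at class i.
client_biases = torch.eye(num_classes)     # row i is client i's output bias
# Each client is perfect on its own single-class data:
print(client_biases.argmax(dim=1))         # tensor([0, 1, 2, ..., 9])
# But the averaged bias is uniform, so its predictions carry no information:
print(client_biases.mean(dim=0))           # tensor of 0.1's -> ~10% accuracy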

Shai
  • With a sufficiently small step size, the federated strategy is guaranteed to converge (it will find a point where the gradient on the training data is 0), regardless of the data distribution. In my second link I do the following: at every step I select a batch for each machine, train them on their batches, and then average the models. It's completely equivalent to selecting a batch of 10x size where each class constitutes 10% of the batch. The effect is the same. – Dmitry Jan 31 '21 at 15:30
  • 1
    [Considering that *"each machine only has access [...] to a single class"*]. *"It's completely equivalent to selecting a batch"* that is not true. Training those models separately then averaging is definitely not the same as training on the whole class set with a single model. Why would you assume that? – Ivan Jan 31 '21 at 16:00
  • @Ivan (please tag me so that I can see the reply). Let `x` be the current model parameters (each machine has the same parameters). For each machine, the update is `x_i <- x - lr * grad_i(x)`. Note that I average the models after each step, so the new global model is `avg(x_i) = x - lr * avg(grad_i(x))`. Each `grad_i` is the average gradient of the loss on the samples in the `i`th model's batch. Since all batch sizes are equal, `avg(grad_i)` is the average gradient of the loss over all samples participating in the step. In other words, it's equal to [see my first comment]. – Dmitry Jan 31 '21 at 16:43
  • 1
    @Shai, The experiments show that the problem was batch normalization (when I use dense model or replace BN layers with Identity layer, it works). Looks like you shouldn't average it like this. Do you know what's the proper way to handle them during model averaging? – Dmitry Jan 31 '21 at 18:06
  • @Dmitry it all seems to stem from the non-random way you split the data. Can you try random splits? Why do you insist on this split? – Shai Jan 31 '21 at 18:47
  • @Shai, because we are writing a paper about federated learning convergence when the data distribution is non-iid. – Dmitry Jan 31 '21 at 21:29
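For the BatchNorm question raised in these comments, the usual options are to replace BN with a normalization layer that has no running statistics (e.g. GroupNorm), or to keep BN state local to each client rather than averaging it, as in FedBN (which keeps entire BN layers local). Below is a minimal sketch of a simplified variant of the latter that only excludes the BN running statistics from the averaging step; the key-name matching is an assumption about how the buffers appear in the state dict, and this is an illustration, not code from the thread.

def server_aggregate_keep_bn_local(server_model, client_models):
    # Same averaging as server_aggregate above, but BatchNorm running statistics
    # are neither averaged nor overwritten on the clients, so each client keeps
    # its own BN statistics.
    def is_bn_buffer(key):
        return ("running_mean" in key or "running_var" in key
                or "num_batches_tracked" in key)

    global_dict = server_model.state_dict()
    for k in global_dict.keys():
        if is_bn_buffer(k):
            continue  # keep the server's own BN statistics
        global_dict[k] = torch.stack(
            [client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0
        ).mean(0)
    server_model.load_state_dict(global_dict)

    # Broadcast the averaged weights, but preserve each client's BN buffers.
    for model in client_models:
        local = model.state_dict()
        for k, v in server_model.state_dict().items():
            if not is_bn_buffer(k):
                local[k] = v
        model.load_state_dict(local)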