I am trying to recreate the Per-FedAvg algorithm in PyTorch. I am training a ResNet-18 network on the MNIST dataset, with the data separated across clients as follows:
Load data
# MNIST pipeline: tensor conversion followed by per-channel normalisation
# (0.1307 / 0.3081 are the dataset's pixel mean / std).
transform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(0.1307, 0.3081),
])

# Train and test splits live in separate roots; download on first run.
trainset = tv.datasets.MNIST(root='C:/data1', train=True, download=True, transform=transform)
testset = tv.datasets.MNIST(root='C:/data0', train=False, download=True, transform=transform)

# Single shared loader for evaluation.
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
Split data
# Split the training set across `size` workers: each worker gets an equal
# share, and the last worker absorbs the remainder so every sample is used
# exactly once (random_split requires the lengths to sum to len(trainset)).
share = len(trainset) // size
lengths = [share] * (size - 1) + [len(trainset) - share * (size - 1)]
Worker_data = torch.utils.data.random_split(trainset, lengths)

# One DataLoader per worker over its private shard.
train_loader = [
    torch.utils.data.DataLoader(shard, batch_size=64, shuffle=True)
    for shard in Worker_data
]
Each client trains on its own dataset inside the train function. However, an error occurs when computing the Hessian. The Hessian function looks like this:
def get_hessian(device, dataseq, model):
loss_fn = torch.nn.CrossEntropyLoss()
index = np.random.randint(0, high=len(dataseq), size=None, dtype=int)
data, target = dataseq[index]
data = data.to(device)
target = target.to(device)
output, _ = model(data)
loss = loss_fn(output, target)
grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True, create_graph=True)
hessian_params = []
for k in range(len(grads)):
hess_params = torch.zeros_like(grads[k])
for i in range(grads[k].size(0)):
if len(grads[k].size()) == 2:
for j in range(grads[k].size(1)):
hess_params[i, j] = torch.autograd.grad(grads[k][i][j], model.parameters(), retain_graph=True)[k][i, j]
else:
hess_params[i] = torch.autograd.grad(grads[k][i], model.parameters(), retain_graph=True)[k][i]
hessian_params.append(hess_params)
return hessian_params
The error is raised at the torch.autograd.grad call, with the following traceback:
grad can be implicitly created only for scalar outputs
File "E:\FYP\Per_FedAvg_test.py", line 115, in get_hessian
hess_params[i] = torch.autograd.grad(grads[k][i], model.parameters(), retain_graph=True)[k][i]
File "E:\FYP\Per_FedAvg_test.py", line 57, in train
hessian_params = get_hessian(device, data, original_model)
File "E:\FYP\Per_FedAvg_test.py", line 152, in main
out_weight = train(rank, init_model, lr, train_loader[rank], device, local_epoch, alpha, beta)
File "E:\FYP\Per_FedAvg_test.py", line 177, in <module>
main()
What is the problem?
By the way, I am limited to PyTorch 1.8, as this is a constraint of my project.