I need to apply ZCA whitening in PyTorch. I think I have found a way to do this using transforms.LinearTransformation, and I have found a test in the PyTorch repo that gives some insight into how it is done (see the final code block, or the link below):

https://github.com/pytorch/vision/blob/master/test/test_transforms.py

I am struggling to work out how I apply something like this myself.

Currently I have transforms along the lines of:

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(np.array([125.3, 123.0, 113.9]) / 255.0,
                             np.array([63.0, 62.1, 66.7]) / 255.0),
    ])

The documentation says the way to use LinearTransformation is as follows:

    torchvision.transforms.LinearTransformation(transformation_matrix, mean_vector)

whitening transformation: Suppose X is a column vector of zero-centered data. Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), perform SVD on this matrix, and pass it as transformation_matrix.

I can see from the tests I linked above and copied below that they are using torch.mm to calculate what they call principal_components:

    def test_linear_transformation(self):
        num_samples = 1000
        x = torch.randn(num_samples, 3, 10, 10)
        flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
        # compute principal components
        sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
        u, s, _ = np.linalg.svd(sigma.numpy())
        zca_epsilon = 1e-10  # avoid division by 0
        d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
        u = torch.Tensor(u)
        principal_components = torch.mm(torch.mm(u, d), u.t())
        mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
        # initialize whitening matrix
        whitening = transforms.LinearTransformation(principal_components, mean_vector)
        # estimate covariance and mean using weak law of large number
        num_features = flat_x.size(1)
        cov = 0.0
        mean = 0.0
        for i in x:
            xwhite = whitening(i)
            xwhite = xwhite.view(1, -1).numpy()
            cov += np.dot(xwhite, xwhite.T) / num_features
            mean += np.sum(xwhite) / num_features
        # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
        assert np.allclose(cov / num_samples, np.identity(1), rtol=2e-3), "cov not close to 1"
        assert np.allclose(mean / num_samples, 0, rtol=1e-3), "mean not close to 0"

        # Checking if LinearTransformation can be printed as string
        whitening.__repr__()

How do I apply something like this? Do I use it where I define my transforms, or do I apply it in my training loop while iterating over my training data?

Thanks in advance

1 Answer

ZCA whitening is typically a preprocessing step, like centering and scaling, that aims at making your data more NN-friendly (additional info below). As such, it is supposed to be applied once, right before training.

So right before you start training your model on a given dataset X, compute the whitened dataset Z, which is simply the product of X with the ZCA matrix W_zca (a sketch of how to compute it is given after the code below). Then train your model on the whitened dataset. In the end, you should have something that looks like this:

    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            # Feel free to use something more useful than a simple linear layer
            self._network = torch.nn.Linear(...)
            # Do your stuff
            ...

        def fit(self, inputs, labels):
            """Trains the model to predict the right label for a given input."""
            # Compute the whitening matrix once, from the (N, D) training inputs
            self._zca_mat = compute_zca(inputs)
            # W_zca is symmetric, so whitening row-vector samples is a right-multiplication
            whitened_inputs = torch.mm(inputs, self._zca_mat)

            # Apply training on the whitened data
            # (assuming an optimizer built over self.parameters())
            optimizer.zero_grad()
            outputs = self._network(whitened_inputs)
            loss = torch.nn.MSELoss()(outputs, labels)
            loss.backward()
            optimizer.step()

        def forward(self, input):
            # You always need to apply the ZCA transform before forwarding,
            # because your network has been trained with whitened data
            whitened_input = torch.mm(input, self._zca_mat)
            return self._network(whitened_input)
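
Note that compute_zca is not defined in the snippet above; here is a minimal sketch of what such a helper could look like, assuming inputs is an (N, D) matrix whose rows are flattened samples (the function name, the eps value, and the use of torch.svd are my own choices, mirroring the torchvision test from the question):

    import torch

    def compute_zca(inputs, eps=1e-10):
        """Return the (D x D) ZCA whitening matrix for an (N, D) data matrix."""
        # Zero-center the data so that sigma below is a covariance matrix
        x = inputs - inputs.mean(dim=0, keepdim=True)
        # Estimate the (D x D) covariance matrix
        sigma = x.t().mm(x) / x.size(0)
        # SVD of the symmetric covariance matrix gives its eigendecomposition
        u, s, _ = torch.svd(sigma)
        # W_zca = U * diag(1 / sqrt(s + eps)) * U^T
        return u.mm(torch.diag(1.0 / torch.sqrt(s + eps))).mm(u.t())

Because W_zca is symmetric, right-multiplying row-vector samples (inputs.mm(w_zca)) is equivalent to the column-vector formulation quoted from the docs in the question.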

Additional info

Whitening your data means decorrelating its dimensions, so that the correlation matrix of the whitened data is the identity matrix. It is a rotation-scaling operation (thus linear), and there are actually infinitely many possible whitening transforms; ZCA is the particular one that keeps the whitened data as close as possible to the original. To understand the maths behind ZCA, read this.
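
As a quick sanity check (a minimal sketch reusing the hypothetical compute_zca helper from above), you can verify the decorrelation property by whitening some correlated data and checking that its covariance is close to the identity:

    import torch

    n, d = 10000, 5
    # Build correlated, zero-centered samples of shape (N, D)
    x = torch.randn(n, d).mm(torch.randn(d, d))
    x = x - x.mean(dim=0, keepdim=True)

    w_zca = compute_zca(x)   # (D, D) whitening matrix
    z = x.mm(w_zca)          # whitened samples

    cov = z.t().mm(z) / n    # covariance of the whitened data
    print(torch.allclose(cov, torch.eye(d), atol=1e-3))  # expected: True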
