
I have trained a Variational Autoencoder (VAE) with an additional fully connected layer after the encoder for binary image classification. It is set up using PyTorch Lightning, and the encoder/decoder is the resnet18 from the PyTorch Lightning Bolts repo.

During training, I print out the classification scores. I can see that the scores of class 1 images tend towards 1 and those of class 0 images tend towards 0 as training progresses. The loss also decreases as expected and reaches a local minimum.

However, the classification scores I get at inference time are completely different from those achieved during training, even for the same images (e.g. the training set).

It also appears that the longer I train, the higher the inference scores get, even for class 0 images. Eventually, all images from the training set score > 0.98.

Why are the inference results different from training?

UPDATE #1: The model appears to give reasonable inference results if I don't call model.eval() and pass a single image at a time. Passing multiple images at once in that mode does not work well, however. My understanding was that model.eval() helps achieve consistent results, so why does it cause this behaviour?
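One mechanism by which model.eval() changes outputs is batch normalisation. Assuming the Bolts resnet18_encoder contains nn.BatchNorm2d layers (as standard ResNets do), in train mode each forward pass normalises with the statistics of the batch it receives, whereas after model.eval() the running averages accumulated during training are used instead. The toy sketch below uses a single hypothetical BatchNorm layer, not the real model, and feeds it two homogeneous batches, loosely mimicking the separate all-positive and all-negative batches passed to the encoder in _run_step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy layer standing in for the BatchNorm2d layers of a ResNet encoder.
bn = nn.BatchNorm2d(num_features=3)

# Two homogeneous batches with very different statistics, mimicking an
# all-positive batch and a separate all-negative batch.
positives = torch.randn(8, 3, 4, 4) + 3.0
negatives = torch.randn(8, 3, 4, 4) - 3.0

# Train mode: each call normalises with the mean/variance of the batch it
# receives and updates running_mean / running_var as a side effect.
bn.train()
pos_train = bn(positives)
neg_train = bn(negatives)

# Eval mode: the stored running statistics are used for every input, so the
# result no longer depends on what else is in the batch.
bn.eval()
with torch.no_grad():
    pos_eval = bn(positives)
    neg_eval = bn(negatives)
    single_eval = bn(positives[:1])

print(f"train-mode means: {pos_train.mean().item():.3f}, {neg_train.mean().item():.3f}")
print(f"eval-mode means:  {pos_eval.mean().item():.3f}, {neg_eval.mean().item():.3f}")
```

In train mode both batches come out centred on zero, so the constant offset between them is removed; in eval mode the offset survives. If something similar happens inside the encoder, the classification head would be trained on features normalised per homogeneous batch but evaluated on features normalised with running statistics averaged over both classes. A single image in train mode is normalised over its own spatial positions, which would roughly resemble a homogeneous training batch. That would be consistent with UPDATE #1, but it is a hypothesis to check, not a confirmed diagnosis.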

Model

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from pl_bolts.models.autoencoders.components import (
    resnet18_encoder,
    resnet18_decoder
)


class VariationalAutoencoder(LightningModule):

    ...

    # Excerpt from __init__ (the rest is elided)
    self.first_conv: bool = False
    self.maxpool1: bool = False
    self.enc_out_dim: int = 512

    self.encoder = resnet18_encoder(self.first_conv, self.maxpool1)
    # Extra head: maps the 512-d encoder output to one logit per image
    self.fc_object_identity = nn.Linear(self.enc_out_dim, 1)

    ...

    def _run_step(self, object_image, negative_object_image):
        # 1. Reconstruction model
        object_image_encoded = self.encoder(object_image)
        mu = self.fc_mu(object_image_encoded)
        log_var = self.fc_var(object_image_encoded)
        p, q, z = self.sample(mu, log_var)

        # 2. Classification model (includes negative observation);
        # the negative images go through the encoder in a separate forward pass
        negative_object_image_encoded = self.encoder(negative_object_image)
        object_image_classification_score = torch.sigmoid(self.fc_object_identity(object_image_encoded))
        negative_object_classification_score = torch.sigmoid(self.fc_object_identity(negative_object_image_encoded))

        return z, self.decoder(z), p, q, object_image_classification_score, negative_object_classification_score

    def forward(self, x):
        x_encoded = self.encoder(x)
        mu = self.fc_mu(x_encoded)
        log_var = self.fc_var(x_encoded)
        p, q, z = self.sample(mu, log_var)

        # The classification score is computed from the encoder output, not from z
        x_classification_score = torch.sigmoid(self.fc_object_identity(x_encoded))

        return self.decoder(z), x_classification_score
```
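If batch statistics turn out to be the issue, one variant worth comparing is to encode the positive and negative images in a single forward pass, so that every batch the encoder normalises during training contains both classes, closer to a mixed inference batch. The helper encode_jointly and the toy encoder below are hypothetical illustrations, not part of Bolts or Lightning; the toy encoder only stands in for resnet18_encoder.

```python
import torch
import torch.nn as nn


def encode_jointly(encoder: nn.Module, pos: torch.Tensor, neg: torch.Tensor):
    """Hypothetical helper: run positive and negative images through the encoder
    as ONE batch, so any BatchNorm layers normalise with mixed-class statistics,
    then split the features back into the two groups."""
    features = encoder(torch.cat([pos, neg], dim=0))
    return features[: pos.shape[0]], features[pos.shape[0]:]


# Toy stand-in for resnet18_encoder (hypothetical; the real one outputs 512 features).
toy_encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
fc = nn.Linear(8, 1)

torch.manual_seed(0)
pos = torch.randn(4, 3, 16, 16)
neg = torch.randn(4, 3, 16, 16)

pos_feat, neg_feat = encode_jointly(toy_encoder, pos, neg)
pos_score = torch.sigmoid(fc(pos_feat))
neg_score = torch.sigmoid(fc(neg_feat))
print(pos_score.shape, neg_score.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
```

Inside _run_step this would replace the two separate self.encoder(...) calls, e.g. `object_image_encoded, negative_object_image_encoded = encode_jointly(self.encoder, object_image, negative_object_image)`, with mu, log_var and the reconstruction still computed from object_image_encoded. In eval mode the helper gives the same features as encoding each group separately; it only changes behaviour in train mode.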

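```python
# Loading the trained checkpoint (module-level code run before inference,
# not part of the class)
variational_autoencoder = VariationalAutoencoder.load_from_checkpoint(
    checkpoint_path=str(checkpoint_file_path)
)
```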

```python
    # (continued) further methods of VariationalAutoencoder

    def step(self, batch, batch_idx):
        object_image, _object_id, negative_object_image, _negative_object_id, negative_file_name, positive_file_name = batch

        z, object_image_hat, p, q, x_classification_score, o_negative_classification_score = self._run_step(object_image, negative_object_image)

        # 1. Reconstruction loss (MSE, summed over every pixel in the batch)
        recon_loss = F.mse_loss(object_image_hat, object_image, reduction="sum")

        # 2. KL loss (latent distribution vs standard Gaussian)
        kl = torch.distributions.kl_divergence(q, p)
        kl = kl.mean()
        kl *= self.kl_coeff

        # 3. Object identity loss (binary cross entropy): positives -> 1, negatives -> 0
        x_identity_loss = F.binary_cross_entropy(
            x_classification_score,
            target=torch.ones_like(x_classification_score)
        )
        o_negative_identity_loss = F.binary_cross_entropy(
            o_negative_classification_score,
            target=torch.zeros_like(o_negative_classification_score)
        )

        # Print the last sample of each batch
        print(f"o_pos: {positive_file_name[-1]}, o_pos_classification_score: {x_classification_score[-1]}, o_pos_target: 1, o_pos_loss: {x_identity_loss}")
        print(f"o_neg: {negative_file_name[-1]}, o_neg_classification_score: {o_negative_classification_score[-1]}, o_neg_target: 0, o_neg_loss: {o_negative_identity_loss}")

        # TOTAL loss (1 + 2 + 3)
        kl_scaled = kl * 25
        recon_loss_scaled = recon_loss * (1/600000)
        x_identity_loss_scaled = x_identity_loss * 5
        o_negative_identity_loss_scaled = o_negative_identity_loss * 5
        loss = kl_scaled + recon_loss_scaled + x_identity_loss_scaled + o_negative_identity_loss_scaled

        # Log the scaled terms so they add up to total_loss
        logs = {
            "kl": kl_scaled,
            "recon_loss": recon_loss_scaled,
            "o_identity_loss": x_identity_loss_scaled,
            "negative_o_identity_loss": o_negative_identity_loss_scaled,
            "total_loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict(
            {f"train_{k}": v for k, v in logs.items()},
            on_step=True,
            on_epoch=True
        )
        return loss
```
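For reference, the total loss that step minimises can be written out from the code. Here β is self.kl_coeff, the bar denotes the mean over all batch and latent elements (kl.mean()), the reconstruction sum runs over every pixel in the batch (reduction="sum"), N is the batch size, and s_i^+ and s_i^- are the sigmoid scores of the positive and negative images (this notation is introduced here, not in the original code):

```latex
\mathcal{L} = 25\,\beta\,\overline{D_{\mathrm{KL}}(q \,\|\, p)}
  + \frac{1}{600000}\sum_{j}\left(\hat{x}_j - x_j\right)^2
  - \frac{5}{N}\sum_{i=1}^{N}\log s_i^{+}
  - \frac{5}{N}\sum_{i=1}^{N}\log\left(1 - s_i^{-}\right)
```

The two classification terms come from separate encoder forward passes (one all-positive batch and one all-negative batch), which matters if any layer in the encoder depends on batch statistics.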

```python
# Inference
variational_autoencoder.eval()  # puts BatchNorm/Dropout layers into evaluation mode
with torch.no_grad():
    predicted_images, classification_score = variational_autoencoder(training_images)
```
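A quick way to test whether mode-dependent layers explain the gap is to score the same images once in eval mode and once in train mode and compare. The function scores_in_both_modes and the ToyModel below are hypothetical; ToyModel only mimics the (reconstruction, score) return shape of the question's forward(). The train-mode pass runs on a deep copy, because BatchNorm updates its running statistics in train mode even under torch.no_grad().

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def scores_in_both_modes(model: nn.Module, images: torch.Tensor):
    """Hypothetical diagnostic: classification scores for the same images in
    eval mode and in train mode. Leaves `model` in eval mode; the train-mode
    pass uses a deep copy so the original running statistics stay untouched."""
    model.eval()
    _, eval_scores = model(images)

    train_copy = copy.deepcopy(model).train()
    _, train_scores = train_copy(images)
    return eval_scores, train_scores


class ToyModel(nn.Module):
    """Hypothetical stand-in returning (reconstruction, score) like the VAE's forward()."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc_object_identity = nn.Linear(4, 1)

    def forward(self, x):
        return x, torch.sigmoid(self.fc_object_identity(self.encoder(x)))


torch.manual_seed(0)
toy = ToyModel()
images = torch.randn(6, 3, 8, 8)
eval_scores, train_scores = scores_in_both_modes(toy, images)
for e, t in zip(eval_scores.squeeze(1), train_scores.squeeze(1)):
    print(f"eval: {e.item():.4f}  train-mode batch: {t.item():.4f}")
```

With the question's model this would be called as `scores_in_both_modes(variational_autoencoder, training_images)`. A large gap between the two columns for the same images would point at layers whose behaviour depends on the mode, such as BatchNorm.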

Results

During the last training epoch

002_masterchefcan_elev153_azim66.png (class 1) : classification score = 0.9322
077_rubikscube_elev3_azim28.png (class 0) : classification score = 0.0684

These results look good: both images are clearly classified correctly.

Inference after training

002_masterchefcan_elev153_azim66.png (class 1) : classification score = 0.9229
077_rubikscube_elev3_azim28.png (class 0) : classification score = 0.8279

I would expect these to be the same as the training scores.

Asked by aktabit
Comments

  • What is your dataset pipeline? Are any random transformations applied to the input? – Ivan May 28 '22 at 14:21
  • @Ivan The dataset (for the purposes of fixing this issue) is 30 PNG images of class 1 and 30 of class 0. The pipeline is a fixed normalisation of the images, which is also applied before inference. There are no random transformations, and the same images are used for training and inference. – aktabit May 28 '22 at 15:38
