I have trained a Variational Autoencoder (VAE) with an additional fully connected layer after the encoder for binary image classification. It is set up using PyTorch Lightning. The encoder and decoder are the resnet18 components from the PyTorch Lightning Bolts repo.
During training, I am printing out the classification scores. I can see the scores of the class 1 images tend towards 1 and the class 0 images tend towards 0 as training progresses. The loss also reduces as expected and reaches a local minimum.
However, the classification scores I get at inference are completely different from the ones seen during training, even for the same images (e.g. the training set).
It also appears that the longer I train, the higher the inference scores get, even for class 0 images, so eventually all images from the training set score >0.98.
Why are the inference results different from training?
UPDATE #1:
Inference gives reasonable results if I don't call model.eval() and pass a single image. Passing multiple images at once does not work well this way. My understanding was that model.eval() helps achieve consistent results, so why does it cause this behaviour?
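For context on what model.eval() changes: in train mode a BatchNorm layer normalizes with the statistics of the current batch, while in eval mode it uses the running estimates accumulated during training. The following is a minimal sketch on a standalone `nn.BatchNorm2d` layer, assuming the encoder contains BatchNorm layers as ResNet-style encoders typically do. The layer, tensor shapes and values are made up for illustration and are not taken from the actual model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for one BatchNorm layer inside the encoder (assumption).
bn = nn.BatchNorm2d(3)

# A batch whose statistics are far from the layer's initial running
# estimates (running_mean = 0, running_var = 1).
batch = torch.randn(8, 3, 4, 4) * 5.0 + 10.0

bn.train()
with torch.no_grad():
    # Train mode: normalized with this batch's own mean and variance.
    # The running estimates are also updated as a side effect.
    out_train = bn(batch)

bn.eval()
with torch.no_grad():
    # Eval mode: normalized with the running estimates instead.
    out_eval = bn(batch)
    # The first image on its own, still in eval mode.
    out_eval_single = bn(batch[:1])

# Same input, different mode: the outputs are expected to disagree.
same_across_modes = torch.allclose(out_train, out_eval, atol=1e-3)

# In eval mode the result for an image should not depend on its batch.
eval_is_batch_independent = torch.allclose(
    out_eval[:1], out_eval_single, atol=1e-6
)

print(same_across_modes, eval_is_batch_independent)
```

If the real model behaves like this toy layer, scoring one batch under both `.train()` and `.eval()` and comparing the results would be a quick way to check whether BatchNorm is involved in the discrepancy.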
Model
from pl_bolts.models.autoencoders.components import (
    resnet18_encoder,
    resnet18_decoder
)

class VariationalAutoencoder(LightningModule):
    ...
    self.first_conv: bool = False
    self.maxpool1: bool = False
    self.enc_out_dim: int = 512
    self.encoder = resnet18_encoder(first_conv, maxpool1)
    self.fc_object_identity = nn.Linear(self.enc_out_dim, 1)
    ...

    def _run_step(self, object_image, negative_object_image):
        # 1. Reconstruction model
        object_image_encoded = self.encoder(object_image)
        mu = self.fc_mu(object_image_encoded)
        log_var = self.fc_var(object_image_encoded)
        p, q, z = self.sample(mu, log_var)

        # 2. Classification model (includes negative observation)
        negative_object_image_encoded = self.encoder(negative_object_image)
        object_image_classification_score = torch.sigmoid(self.fc_object_identity(object_image_encoded))
        negative_object_classification_score = torch.sigmoid(self.fc_object_identity(negative_object_image_encoded))

        return z, self.decoder(z), p, q, object_image_classification_score, negative_object_classification_score

    def forward(self, x):
        x_encoded = self.encoder(x)
        mu = self.fc_mu(x_encoded)
        log_var = self.fc_var(x_encoded)
        p, q, z = self.sample(mu, log_var)
        x_classification_score = torch.sigmoid(self.fc_object_identity(x_encoded))
        return self.decoder(z), x_classification_score
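`self.sample` is not shown in the post. The following is a minimal sketch of what such a method commonly looks like, assuming a diagonal Gaussian posterior, a standard normal prior and the reparameterization trick. The standalone `sample` function below is a hypothetical stand-in for illustration, not the actual method.

```python
import torch


def sample(mu: torch.Tensor, log_var: torch.Tensor):
    """Hypothetical stand-in for self.sample (assumed, not from the post)."""
    std = torch.exp(log_var / 2)
    # Prior p(z): standard normal with the same shape as the posterior.
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    # Approximate posterior q(z|x), parameterized by the encoder outputs.
    q = torch.distributions.Normal(mu, std)
    # rsample() draws mu + std * eps, so gradients flow back to mu and log_var.
    z = q.rsample()
    return p, q, z


# Toy encoder outputs: batch of 4, latent dimension 8 (made-up sizes).
mu = torch.zeros(4, 8, requires_grad=True)
log_var = torch.zeros(4, 8, requires_grad=True)

p, q, z = sample(mu, log_var)

# Same KL term as in the training step; zero here because q equals p.
kl = torch.distributions.kl_divergence(q, p).mean()
```

Note that in `forward` the classification score is computed from `x_encoded`, not from `z`, so the randomness of this sampling step affects the reconstruction but not the score.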
variational_autoencoder = VariationalAutoencoder.load_from_checkpoint(
    checkpoint_path=str(checkpoint_file_path)
)
# Training methods of VariationalAutoencoder
def step(self, batch, batch_idx):
    (
        object_image,
        _object_id,
        negative_object_image,
        _negative_object_id,
        negative_file_name,
        positive_file_name,
    ) = batch

    z, object_image_hat, p, q, x_classification_score, o_negative_classification_score = self._run_step(object_image, negative_object_image)

    # 1. Reconstruction loss (MSE)
    recon_loss = F.mse_loss(object_image_hat, object_image, reduction="sum")

    # 2. KL loss (latent distribution vs standard Gaussian)
    kl = torch.distributions.kl_divergence(q, p)
    kl = kl.mean()
    kl *= self.kl_coeff

    # 3. Object identity loss (binary cross entropy)
    x_identity_loss = F.binary_cross_entropy(
        x_classification_score,
        target=torch.ones_like(x_classification_score)
    )
    o_negative_identity_loss = F.binary_cross_entropy(
        o_negative_classification_score,
        target=torch.zeros_like(o_negative_classification_score)
    )

    # Print the score of the last image in each batch, plus the batch losses
    print(f"o_pos: {positive_file_name[-1]}, o_pos_classification_score: {x_classification_score[-1]}, o_pos_target: 1, o_pos_loss: {x_identity_loss}")
    print(f"o_neg: {negative_file_name[-1]}, o_neg_classification_score: {o_negative_classification_score[-1]}, o_neg_target: 0, o_neg_loss: {o_negative_identity_loss}")

    # TOTAL loss (1 + 2 + 3)
    kl_scaled = kl * 25
    recon_loss_scaled = recon_loss * (1/600000)
    x_identity_loss_scaled = x_identity_loss * 5
    o_negative_identity_loss_scaled = o_negative_identity_loss * 5
    loss = kl_scaled + recon_loss_scaled + x_identity_loss_scaled + o_negative_identity_loss_scaled

    # Log the scaled terms so they add up to total_loss
    logs = {
        "kl": kl_scaled,
        "recon_loss": recon_loss_scaled,
        "o_identity_loss": x_identity_loss_scaled,
        "negative_o_identity_loss": o_negative_identity_loss_scaled,
        "total_loss": loss,
    }
    return loss, logs

def training_step(self, batch, batch_idx):
    loss, logs = self.step(batch, batch_idx)
    self.log_dict(
        {f"train_{k}": v for k, v in logs.items()},
        on_step=True,
        on_epoch=True
    )
    return loss
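One detail of the training code that may matter here: `_run_step` sends the positive batch and the negative batch through the encoder in two separate calls. In train mode, a BatchNorm layer would then normalize each batch with its own statistics, so the features reaching the classifier head depend on how the batches are composed. The sketch below illustrates this on a toy `nn.BatchNorm1d` layer with made-up data. It is an assumption that this effect contributes to the behaviour described in the question, and nothing in the sketch is taken from the actual model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy features: the two classes differ only by a constant offset (made up).
pos = torch.randn(16, 4) + 3.0
neg = torch.randn(16, 4) - 3.0

bn = nn.BatchNorm1d(4, affine=False)
bn.train()

with torch.no_grad():
    # Separate calls, as in _run_step: each batch is centered on its own
    # mean, so the offset between the classes is normalized away.
    pos_sep = bn(pos)
    neg_sep = bn(neg)

    # One mixed batch: both classes share the same statistics, so the
    # offset between them survives normalization.
    mixed = bn(torch.cat([pos, neg]))
    pos_mix, neg_mix = mixed[:16], mixed[16:]

# Distance between the class means after normalization.
gap_separate = (pos_sep.mean() - neg_sep.mean()).abs().item()
gap_mixed = (pos_mix.mean() - neg_mix.mean()).abs().item()

print(f"separate calls: {gap_separate:.4f}, mixed batch: {gap_mixed:.4f}")
```

In eval mode neither of these per-batch normalizations is applied. The layer uses its running estimates, which were accumulated from both kinds of batch, so the classifier head can see features distributed differently from anything it saw during training.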
# Inferencing
variational_autoencoder.eval()
with torch.no_grad():
    predicted_images, classification_score = variational_autoencoder(training_images)
Results
During last training epoch
002_masterchefcan_elev153_azim66.png (class 1) : classification score = 0.9322
077_rubikscube_elev3_azim28.png (class 0) : classification score = 0.0684
These results look good: both images are clearly classified correctly.
Inferencing after training
002_masterchefcan_elev153_azim66.png (class 1) : classification score = 0.9229
077_rubikscube_elev3_azim28.png (class 0) : classification score = 0.8279
I would expect these to be the same as the training scores.