
How to transform vectors of labels to one-hot encoding and back in Pytorch?

I'm posting the solution here because I had to read through an entire forum discussion to find it, instead of getting a simple answer from a Google search.

Gulzar

2 Answers


From the PyTorch forums:

```python
import torch
import numpy as np


labels = torch.randint(0, 10, (10,))

# labels --> one-hot (num_classes is inferred as labels.max() + 1)
one_hot = torch.nn.functional.one_hot(labels)
# one-hot --> labels: index of the single 1 in each row
labels_again = torch.argmax(one_hot, dim=1)

np.testing.assert_equal(labels.numpy(), labels_again.numpy())
```
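The same round trip can be sketched with an explicit `num_classes` and a few made-up labels. `one_hot` returns an int64 tensor, so this sketch casts it to float, assuming you want to combine it with float tensors such as model outputs. Using `dim=-1` in `argmax` also works for label tensors of any shape, because `one_hot` appends the class dimension last.

```python
import torch
import torch.nn.functional as F

num_classes = 10
labels = torch.tensor([3, 0, 9, 3, 7])  # example labels, chosen arbitrarily

# labels --> one-hot; the class dimension is appended last and the dtype is int64
one_hot = F.one_hot(labels, num_classes=num_classes)
one_hot_float = one_hot.float()  # cast for arithmetic with float tensors

# one-hot --> labels: argmax over the last (class) dimension
labels_again = one_hot_float.argmax(dim=-1)

print(one_hot.shape, one_hot.dtype)  # torch.Size([5, 10]) torch.int64
print(labels_again.tolist())         # [3, 0, 9, 3, 7]
```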
Gulzar
    Note: see [this answer](https://stackoverflow.com/a/74784150/913098) for cases where you have to specify the number of classes (which is most of the time). – Gulzar Dec 14 '22 at 12:45

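Since I can't comment on the accepted answer, I'll add this here: if your target does not include all classes (e.g. because you train in batches), you can pass the number of classes as an argument. Without it, `one_hot` infers the number of classes as one more than the largest value in the input, so a batch that happens to miss the highest classes produces a narrower tensor: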

```python
# labels --> one-hot
one_hot = torch.nn.functional.one_hot(target, num_classes=7)
```
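A short sketch of the difference, using a hypothetical batch `target` that only contains classes 0 to 2 out of 7:

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: only classes 0-2 appear, but the dataset has 7 classes
target = torch.tensor([0, 2, 1, 2])

inferred = F.one_hot(target)              # width inferred as target.max() + 1 == 3
fixed = F.one_hot(target, num_classes=7)  # width fixed at 7 for every batch

print(inferred.shape)  # torch.Size([4, 3])
print(fixed.shape)     # torch.Size([4, 7])
```

With the fixed width, every batch yields tensors of the same shape, and `argmax(dim=1)` still recovers the original labels.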
swageta