
I don't understand this line:

lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

There is no comment, so is it some well-known Python (or PyTorch?) idiom? Could someone explain what it means, or show a different way that makes the intent clearer?

lprobs is a PyTorch Tensor, and it could hold a floating-point dtype of any precision (I doubt this code is intended to support integer or complex types). As far as I know, the Tensor classes don't override the __ne__ function.

jonrsharpe
Darren Cook
  • It looks like `lprobs[lprobs != lprobs]` is selecting those elements of `lprobs` that contain NaN. Then it sets those to negative infinity...? – khelwood May 12 '21 at 08:11
  • I'd guess a Tensor is based on e.g. a numpy matrix, this is the idiom: https://numpy.org/doc/stable/user/basics.indexing.html#boolean-or-mask-index-arrays – jonrsharpe May 12 '21 at 08:13
  • @jonrsharpe I'm familiar with the `x[b]` idea. I'm confused by `x[x!=x]`. – Darren Cook May 12 '21 at 08:18
  • This _is_ `x[b]`, where `b = x != x`. In the linked example the comparison (see also https://numpy.org/doc/stable/user/basics.rec.html?highlight=comparison#structure-comparison) that creates the mask is `y>20`. – jonrsharpe May 12 '21 at 08:20
  • see [this](https://stackoverflow.com/questions/1565164/what-is-the-rationale-for-all-comparisons-returning-false-for-ieee754-nan-values) – Gulzar May 12 '21 at 08:22
  • Does this answer your question? [In what situation is an object not equal to itself?](https://stackoverflow.com/questions/59253497/in-what-situation-is-an-object-not-equal-to-itself) – iacob May 12 '21 at 09:03

1 Answer


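It's a combination of fancy indexing with a boolean mask and a "trick" (although intended by design) for detecting NaN: lprobs != lprobs is evaluated elementwise and produces a boolean tensor, and for floating-point values x != x holds iff x is NaN, because IEEE 754 defines NaN to compare unequal to everything, including itself. Indexing with that mask then assigns -inf to exactly the NaN positions.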
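A quick plain-Python check of the NaN property (no PyTorch needed; the list of values is purely illustrative):

```python
import math

nan = float("nan")
values = [1.0, nan, -2.5, math.inf]

# Under IEEE 754, NaN compares unequal to every value, including itself,
# so `v != v` is True only for the NaN entry (infinity equals itself).
self_unequal = [v != v for v in values]
explicit = [math.isnan(v) for v in values]

print(self_unequal)  # [False, True, False, False]
print(explicit)      # [False, True, False, False]
```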

They could alternatively have written

lprobs[torch.isnan(lprobs)] = torch.tensor(-math.inf).to(lprobs)

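or, probably even more idiomatically, used torch.nan_to_num (but beware that, by default, it also replaces +inf and -inf with the largest and smallest finite values of the dtype, so you would have to pass posinf=math.inf and neginf=-math.inf to leave infinities alone).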
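A small sketch comparing these options on a made-up four-element float32 tensor (the values are arbitrary). It checks that the x != x mask matches torch.isnan, and shows that torch.nan_to_num only leaves an existing -inf alone when posinf and neginf are passed explicitly:

```python
import math
import torch

# Arbitrary example values: one NaN and one genuine -inf.
lprobs = torch.tensor([-0.5, float("nan"), -math.inf, -2.0])

# x != x is evaluated elementwise and is True exactly at the NaN entries.
assert torch.equal(lprobs != lprobs, torch.isnan(lprobs))

# The original idiom, applied to a copy so lprobs stays intact.
via_mask = lprobs.clone()
via_mask[via_mask != via_mask] = torch.tensor(-math.inf).to(via_mask)

# The torch.isnan spelling of the same in-place update.
via_isnan = lprobs.clone()
via_isnan[torch.isnan(via_isnan)] = torch.tensor(-math.inf).to(via_isnan)

# With default posinf/neginf, nan_to_num also rewrites the existing -inf
# to the most negative finite float32 value.
nan_only = torch.nan_to_num(lprobs, nan=-math.inf)

# Passing the infinities explicitly leaves them as they were.
keep_inf = torch.nan_to_num(lprobs, nan=-math.inf, posinf=math.inf, neginf=-math.inf)
```

Here via_mask, via_isnan and keep_inf all end up as [-0.5, -inf, -inf, -2.0], while nan_only has a large finite negative number where the original -inf was.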

A non-updating variant of the above would be

torch.where(torch.isnan(lprobs), torch.tensor(-math.inf).to(lprobs), lprobs)
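Because torch.where returns a new tensor, the input is left untouched. A minimal sketch with made-up 2x2 values; the 0-dim -inf tensor is converted with .to(lprobs) so that it shares lprobs' dtype and device, mirroring the original line:

```python
import math
import torch

# Made-up example values containing two NaNs.
lprobs = torch.tensor([[-0.1, float("nan")], [float("nan"), -3.0]])

# Match lprobs' dtype and device, as the original line does.
neg_inf = torch.tensor(-math.inf).to(lprobs)

# Take -inf where the element is NaN, otherwise keep the original element.
# The 0-dim neg_inf broadcasts against the 2x2 shape.
cleaned = torch.where(torch.isnan(lprobs), neg_inf, lprobs)
```

Writing lprobs = torch.where(...) rebinds the name rather than mutating the tensor, which matters if other code still holds a reference to the original.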
phipsgabler
  • Thanks. I'm used to using `is.nan` in R, and `isNaN()` in JavaScript, so naturally I like your alternative version much more! At first glance `torch.nan_to_num()` looks really nice and descriptive, but it also replaces -inf and +inf, which (IMHO) completely spoils the good idea (the function name is misleading rather than descriptive, and a possible efficiency concern too). – Darren Cook May 12 '21 at 10:01
  • Uh, right. I only stumbled upon `nan_to_num` today and didn't realize it has that kind of behaviour. – phipsgabler May 12 '21 at 11:49