
I can't figure out what kind of Python syntax is being used here. I hope someone can explain it.

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T
from torchvision.io import read_image


plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)


def show(imgs):
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

The part I'm wondering about is this line:

        img = T.ToPILImage()(img.to('cpu'))

It looks as if the ToPILImage class is being instantiated with an image object, except that the image is passed outside the parentheses instead of as an argument inside them.

The code is from this page.
https://pytorch.org/vision/stable/auto_examples/plot_scripted_tensor_transforms.html#sphx-glr-auto-examples-plot-scripted-tensor-transforms-py

The source code of ToPILImage is on the page linked below.
https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#ToPILImage

On that page, the class defines only three special methods: __init__, __call__, and __repr__.

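# Imports this excerpt relies on (in torchvision's transforms.py they are relative imports)
from torchvision.transforms import functional as F
from torchvision.utils import _log_api_usage_once


class ToPILImage: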
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
            ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        if self.mode is not None:
            format_string += f"mode={self.mode}"
        format_string += ")"
        return format_string
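For comparison, here is a minimal torch-free sketch of a class built the same way: `__init__` stores configuration when the object is created, `__call__` does the work each time the instance is called, and `__repr__` controls how it prints. `Clamp` and its `low`/`high` parameters are hypothetical, not part of torchvision.

```python
class Clamp:
    """Hypothetical transform with the same three special methods as ToPILImage."""

    def __init__(self, low=0, high=255):
        # Runs once, when the object is created: Clamp(...)
        self.low = low
        self.high = high

    def __call__(self, values):
        # Runs every time the instance is used like a function: clamp(values)
        return [min(max(v, self.low), self.high) for v in values]

    def __repr__(self) -> str:
        # Controls how the instance is displayed, e.g. by print() or in the REPL
        return f"{self.__class__.__name__}(low={self.low}, high={self.high})"


clamp = Clamp(high=10)
print(clamp)               # Clamp(low=0, high=10)
print(clamp([-5, 3, 42]))  # [0, 3, 10]
print(Clamp()([300, -1]))  # [255, 0]
```

The last line has the same shape as `T.ToPILImage()(img)`: create an instance, then call it straight away.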

My guesses are:

  • __init__ is called at some point in that process.
  • some kind of object casting happens in the process.

But I don't know how those two are linked in that call.

Dummy
  • There's no object casting going on. Calling the `ToPILImage` object simply invokes its magic `__call__` method, which is the method Python runs when an object is called as if it were a function. – Abirbhav G. Oct 19 '22 at 17:32
  • @AbirbhavG. Oh, so it was `__call__`, not `__init__`. I had never used that special method. – Dummy Oct 19 '22 at 17:44
  • @Woodford Thank you for the reference! I read it thoroughly. That reference and the answer from Alec Petersen cleared it up! – Dummy Oct 19 '22 at 17:46
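The first comment's point can be checked directly: for an instance `obj`, the call `obj(x)` makes Python look up `__call__` on the object's class and run it with the instance as `self`, which amounts to `type(obj).__call__(obj, x)`. A small sketch (the `Doubler` class is hypothetical):

```python
class Doubler:
    """Hypothetical callable class used only for illustration."""

    def __call__(self, x):
        return 2 * x


double = Doubler()

via_syntax = double(21)                         # the usual call syntax
via_method = type(double).__call__(double, 21)  # the method that syntax looks up

print(via_syntax, via_method)  # 42 42
```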

2 Answers

img = T.ToPILImage()(img.to('cpu'))

T.ToPILImage() must return some callable object: here, an instance of the ToPILImage class, which is callable because the class defines a __call__ method.

Then this object is called, with img.to('cpu') as an argument.

If it helps, think of the code this way:

to_pil = T.ToPILImage()          # create the instance (__init__ runs here)
img = to_pil(img.to('cpu'))      # call the instance (__call__ runs here)
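The same chaining works with plain functions: in `f()(x)`, Python evaluates `f()` first and then calls whatever it returned with `x`. A minimal sketch (`make_greeter` is a hypothetical function, not part of torchvision):

```python
def make_greeter(greeting):
    """Return a new function that prefixes its argument with `greeting`."""
    def greet(name):
        return f"{greeting}, {name}!"
    return greet


# One expression: make_greeter("Hello") runs first and returns `greet`,
# then that returned function is called with "Ada".
one_line = make_greeter("Hello")("Ada")

# The same thing split into two steps.
greeter = make_greeter("Hello")
two_steps = greeter("Ada")

print(one_line)               # Hello, Ada!
print(one_line == two_steps)  # True
```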
John Gordon

The class is not being instantiated with the image object as an argument. It is instantiated by T.ToPILImage() (this is when __init__ runs). The resulting instance is then called like a function, with the image as its argument; that works because the class implements the __call__ method.
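A minimal sketch that records when each special method runs, assuming a hypothetical `Traced` class standing in for `T.ToPILImage` (it works on strings, so no torch is needed):

```python
events = []  # records the order in which the special methods run


class Traced:
    """Hypothetical stand-in for T.ToPILImage that logs its special methods."""

    def __init__(self):
        events.append("__init__")

    def __call__(self, pic):
        events.append(f"__call__({pic!r})")
        return pic.upper()


# Same shape as T.ToPILImage()(img): create, then immediately call.
result = Traced()("img")
print(events)  # ['__init__', "__call__('img')"]

# Keeping the instance lets you call it repeatedly; __init__ does not run again.
to_upper = Traced()
to_upper("a")
to_upper("b")
# ['__init__', "__call__('img')", '__init__', "__call__('a')", "__call__('b')"]
print(events)
```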

A simple example:

class Add:
    def __call__(self, x, y):
        return x + y

adder = Add()
summed = adder(1, 3)
print(summed)
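Only instances whose class defines `__call__` can be used this way. A short sketch using the built-in `callable()` (the `Plain` class is hypothetical):

```python
class Add:
    def __call__(self, x, y):
        return x + y


class Plain:
    """Hypothetical class with no __call__ method."""


print(callable(Add()))    # True: the class defines __call__
print(callable(Plain()))  # False

try:
    Plain()(1, 3)
except TypeError as exc:
    print("TypeError:", exc)  # TypeError: 'Plain' object is not callable
```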