I recently came across this piece of code while trying to implement a transformer using Keras.

class Transformer(keras.Model):
    def __init__(
        self,
        num_hid=64,
        num_head=2,
        num_feed_forward=128,
        source_maxlen=100,
        target_maxlen=100,
        num_layers_enc=4,
        num_layers_dec=1,
        num_classes=60,
    ):
        super().__init__()
    
    ...

    def train_step(self, batch):

        ....

        with tf.GradientTape() as tape:
            preds = self([source, dec_input]) #problematic line
            one_hot = tf.one_hot(dec_target, depth=self.num_classes)
            mask = tf.math.logical_not(tf.math.equal(dec_target, pad_token_idx))
            loss = self.compiled_loss(one_hot, preds, sample_weight=mask)

How can they call self() inside the train_step method of class Transformer? Which class is it initializing?

Is it initializing the GradientTape class? If so, how does it work?

Is it initializing the Model class? If so, is it allowed in Python to initialize the parent class inside a method of the child class using self? How does it work?

I tried this.

class A:
    def __init__(self, x=None, y=None):
        print("A:",x,y)

class B(A):
    def __init__(self):
        super().__init__()
    def call(self, x, y):
        c = self(x,y)

b = B()
b.call(1,2)

Output is:

A: None None


Traceback (most recent call last):
  File "./Playground/file0.py", line 12, in <module>
    b.call(1,2)
  File "./Playground/file0.py", line 9, in call
    c = self(x,y)
TypeError: 'B' object is not callable

As expected, it doesn't work: I cannot create a new object using self() inside the object.

1 Answer

You are confusing constructors (__init__) with call methods (__call__). B is the class; B() constructs an instance and runs __init__. b is an instance; b() "calls" that instance, which runs __call__.

self(*args) is the same as self.__call__(*args):

class A:
    def __init__(self, x=None, y=None):
        print("A:",x,y)

class B(A):
    def __init__(self):
        super().__init__()
    def call(self, x, y):
        c = self(x,y)
    def __call__(self, x, y):
        print("B:", x, y)

b = B()
b.call(1,2)
b(3,4)

Output:

A: None None
B: 1 2
B: 3 4

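In your example you didn't define __call__, hence the error message. keras.Model does define __call__, which in turn runs the model's call method, so self([source, dec_input]) in train_step runs the model's forward pass on the existing instance; it does not initialize any class.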
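To connect this back to the Keras snippet in the question, here is a minimal sketch of the same pattern without Keras. The Model class below is a hypothetical stand-in, not the real keras.Model, and Doubler is an invented subclass for illustration. It only assumes that the base class defines __call__ and forwards to call, which is roughly how a Keras model behaves when you write self(inputs):

```python
class Model:
    """Hypothetical stand-in for keras.Model (not the real class)."""

    def __call__(self, *args, **kwargs):
        # Calling an instance, e.g. model(x) or self(x), ends up here.
        # The base class forwards to the subclass's call() method.
        return self.call(*args, **kwargs)

    def call(self, inputs):
        raise NotImplementedError("subclasses implement the forward pass")


class Doubler(Model):
    """Hypothetical subclass: only defines call(), never __call__."""

    def call(self, inputs):
        return [2 * v for v in inputs]

    def train_step(self, batch):
        # self(batch) invokes the inherited __call__ on this existing
        # instance; no new object is constructed here.
        preds = self(batch)
        return preds


model = Doubler()                      # construction: runs __init__
result = model.train_step([1, 2, 3])   # inside, self(batch) runs __call__
print(result)
```

Doubler never defines __call__ itself; it inherits it from the base class, just as Transformer inherits it from keras.Model in the question.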

Julien