from tqdm import tqdm

class Engine(object):
    def __init__(self):
        hook_names = ['on_start', 'on_start_epoch', 'on_sample', 'on_forward',
                      'on_backward', 'on_end_epoch', 'on_update', 'on_end']

        self.hooks = {}
        for hook_name in hook_names:
            self.hooks[hook_name] = lambda state: None  # Defined here

    def train(self, **kwargs):
        state = {
            'model': kwargs['model'],
            'loader': kwargs['loader'],
            'optim_method': kwargs['optim_method'],
            'optim_config': kwargs['optim_config'],
            'max_epoch': kwargs['max_epoch'],
            'epoch': 0, # epochs done so far
            't': 0, # samples seen so far
            'batch': 0, # samples seen in current epoch
            'stop': False
        }

        state['optimizer'] = state['optim_method'](state['model'].parameters(), **state['optim_config'])

        self.hooks['on_start'](state)             # Called here. What does this do?
        while state['epoch'] < state['max_epoch'] and not state['stop']:
            state['model'].train()

            self.hooks['on_start_epoch'](state)

            state['epoch_size'] = len(state['loader'])

            for sample in tqdm(state['loader'], desc="Epoch {:d} train".format(state['epoch'] + 1)):
                state['sample'] = sample
                self.hooks['on_sample'](state)   # Called again here

                state['optimizer'].zero_grad()
                loss, state['output'] = state['model'].loss(state['sample'])
                self.hooks['on_forward'](state)   # Called

                loss.backward()
                self.hooks['on_backward'](state)   # Called

                state['optimizer'].step()

                state['t'] += 1
                state['batch'] += 1
                self.hooks['on_update'](state)      # Called

            state['epoch'] += 1
            state['batch'] = 0
            self.hooks['on_end_epoch'](state)

        self.hooks['on_end'](state)          #  Called

I'm fairly new to Python and deep learning. The code above is part of the Prototypical Networks implementation (https://github.com/jakesnell/prototypical-networks/blob/c9bb4d258267c11cb6e23f0a19242d24ca98ad8a/protonets/engine.py). The annotated lambda function seems strange to me. I tried printing the type of `self.hooks['on_start']` to understand it, but I only got "None". I'm confused: can someone explain what it's used for in the code above?
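To see when each hook fires, here is a stripped-down sketch of the same loop. It is a simplification under stated assumptions: `run()` is a hypothetical stand-in for `Engine.train()`, and the model, optimizer, and `tqdm` calls are removed so it runs without PyTorch. Every no-op hook is replaced by one that records its own name.

```python
# Hypothetical stand-in for Engine.train(): same hook calls in the same order,
# but model.loss(), loss.backward() and optimizer.step() are omitted.
def run(hooks, loader, max_epoch):
    state = {'epoch': 0, 't': 0, 'batch': 0, 'stop': False,
             'max_epoch': max_epoch, 'loader': loader}
    hooks['on_start'](state)
    while state['epoch'] < state['max_epoch'] and not state['stop']:
        hooks['on_start_epoch'](state)
        for sample in state['loader']:
            state['sample'] = sample
            hooks['on_sample'](state)
            hooks['on_forward'](state)   # the real engine runs model.loss() first
            hooks['on_backward'](state)  # the real engine runs loss.backward() first
            state['t'] += 1
            state['batch'] += 1
            hooks['on_update'](state)    # the real engine runs optimizer.step() first
        state['epoch'] += 1
        state['batch'] = 0
        hooks['on_end_epoch'](state)
    hooks['on_end'](state)
    return state

hook_names = ['on_start', 'on_start_epoch', 'on_sample', 'on_forward',
              'on_backward', 'on_end_epoch', 'on_update', 'on_end']
calls = []
# The default argument binds the current name; a plain closure over `name`
# would see only the last value of the comprehension variable.
hooks = {name: (lambda state, name=name: calls.append(name))
         for name in hook_names}

final = run(hooks, loader=['a', 'b'], max_epoch=1)
print(calls)
```

With one epoch over two samples, `on_sample`, `on_forward`, `on_backward`, and `on_update` appear twice each, bracketed by `on_start`/`on_start_epoch` and `on_end_epoch`/`on_end`.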

  • Are you asking [what a lambda function is](https://stackoverflow.com/q/16501/11082165), or are you asking why the author is creating lambda functions that just return `None`? – Brian61354270 Dec 07 '22 at 15:38
  • Are you sure there isn't some other function that fills in your hooks? The hooks here are only initialized to functions that return None. `type(lambda state: None)` returns `function`. – 0x0fba Dec 07 '22 at 15:43
  • Following your tip, I found functions that fill in these hooks in another file, and I think I kinda get it! Thank you very much! – xiangguo SHEN Dec 07 '22 at 16:06
  • @xiangguo SHEN You should post the answer to help others with the same question. – kmkurn Dec 08 '22 at 03:01
  • @kmkurn In fact 0x0fba already answered this question in the comments. Or did you mean I should post an answer with the "Answer Your Question" button rather than a comment? I'm also new to Stack Overflow. – xiangguo SHEN Dec 08 '22 at 03:28
  • Yes, that's what I meant. You asked what the lambda function is used for in the code, and then said you kinda get it. I think it'll help others if you share what you figured out as an answer, because 0x0fba's comment alone isn't clear enough (at least to me). – kmkurn Dec 08 '22 at 04:33
  • @Brian What I was asking is why the author is creating lambda functions that just return None. Thanks for your comment! – xiangguo SHEN Dec 08 '22 at 12:25
  • @kmkurn OK, I have done this – xiangguo SHEN Dec 08 '22 at 12:26
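The distinction 0x0fba points out can be checked directly: the dictionary holds a function object, and only the *result of calling* that function is `None`. A minimal sketch; the `state` dict passed in is an arbitrary placeholder.

```python
# A hook created the same way as in Engine.__init__
hook = lambda state: None

print(type(hook))          # <class 'function'>  (the stored object is a function)
print(hook({'epoch': 0}))  # None                (calling it does nothing)
```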


I was confused about why the author defined a function that does nothing and calls it many times.

The lambda that returns `None` is defined in the class's `__init__` method; it gives every key in `self.hooks` a default function that does nothing.
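A minimal sketch of why that default matters, assuming a hypothetical `fire()` helper that stands in for the engine's `self.hooks[name](state)` calls: with a no-op registered under every name, the engine can call each hook unconditionally; without it, any hook nobody set would raise `KeyError`.

```python
HOOK_NAMES = ['on_start', 'on_end']

def fire(hooks, name, state):
    # Same call shape as Engine.train(): no "is this hook set?" check.
    return hooks[name](state)

# With defaults: every name maps to a do-nothing function. The lambda ignores
# `name`, so the usual late-binding pitfall with lambdas in a loop doesn't apply.
with_defaults = {name: (lambda state: None) for name in HOOK_NAMES}
fire(with_defaults, 'on_start', {})  # safe, returns None

# Without defaults: only hooks that were explicitly set exist.
without_defaults = {}
try:
    fire(without_defaults, 'on_start', {})
    missing = None
except KeyError as err:
    missing = err.args[0]
print('missing hook:', missing)  # missing hook: on_start
```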

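In practice, the author reassigns new functions to the keys in `hooks`, so calling `hooks[key]` then does real work instead of nothing. Those reassignments happen in another Python file, which is why I didn't find them at first. Because every hook starts as a no-op, `train()` can call all of them without checking whether a replacement was ever supplied.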
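A sketch of the "other file" side, under clear assumptions: `TinyEngine` is a hypothetical, trimmed-down version of `Engine` (no PyTorch, only three hooks), and `record_update` and `stop_after_first_epoch` are illustrative hook functions, not the ones the repository actually defines. It shows user code overwriting the no-op defaults, and a hook changing the shared `state` dict to end training early.

```python
class TinyEngine:
    """Hypothetical, trimmed-down Engine: same hook mechanism, no PyTorch."""

    def __init__(self):
        self.hooks = {name: (lambda state: None)
                      for name in ['on_start_epoch', 'on_update', 'on_end_epoch']}

    def train(self, loader, max_epoch):
        state = {'epoch': 0, 't': 0, 'stop': False}
        while state['epoch'] < max_epoch and not state['stop']:
            self.hooks['on_start_epoch'](state)  # never reassigned below: stays a no-op
            for sample in loader:
                state['t'] += 1
                self.hooks['on_update'](state)
            state['epoch'] += 1
            self.hooks['on_end_epoch'](state)
        return state

# User code (e.g. a training script) replacing some of the defaults.
seen = []

def record_update(state):
    seen.append(state['t'])

def stop_after_first_epoch(state):
    state['stop'] = True  # hooks receive the shared state dict and may modify it

engine = TinyEngine()
engine.hooks['on_update'] = record_update
engine.hooks['on_end_epoch'] = stop_after_first_epoch
final = engine.train(loader=[10, 20, 30], max_epoch=5)
print(final['epoch'], seen)  # 1 [1, 2, 3]
```

Although `max_epoch` is 5, the loop exits after one epoch because the replacement `on_end_epoch` hook sets `state['stop']`.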

AlexK