I recently started using Metaflow for my hyperparameter searches. I'm using a foreach for all my parameters as follows:

from metaflow import FlowSpec, step

class HPOFlow(FlowSpec):

    @step
    def start_hpo(self):
        self.next(self.train_model, foreach='hpo_parameters')

    @step
    def train_model(self):
        # Trains one model per hyperparameter combination...
        ...

This works in that it starts the train_model step as intended, but unfortunately Metaflow tries to run all of the foreach tasks in parallel at once. This makes my GPU/CPU run out of memory, which instantly fails the step.

Is there a way to tell Metaflow to run these tasks linearly / one at a time, or is there another workaround?

Thanks

BBQuercus

2 Answers


@BBQuercus You can limit parallelization by using the --max-workers flag.

Currently, we run no more than 16 tasks in parallel, and you can override that, for example, with python myflow.py run --max-workers 32.
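The effect of --max-workers is comparable to capping the size of a worker pool. A rough stdlib analogy (not Metaflow itself; the names here are illustrative):

```python
from concurrent.futures import ThreadPoolExecutor

def train_model(params):
    # Stand-in for a memory-hungry training task.
    return sum(params)

hpo_parameters = [(1, 2), (3, 4), (5, 6)]

# max_workers=1 runs the foreach-style tasks one at a time,
# much like `run --max-workers 1` does for Metaflow tasks.
with ThreadPoolExecutor(max_workers=1) as pool:
    results = list(pool.map(train_model, hpo_parameters))

print(results)  # [3, 7, 11]
```

With max_workers=1, each task must finish before the next starts, so peak memory use is that of a single task.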

Savin
  • is there a way to specify this for a particular step, or customize the `max-workers` in different steps launched by `foreach`? – themantalope Apr 05 '20 at 19:41
  • Not currently, we just opened a Github issue tracking the feature request - https://github.com/Netflix/metaflow/issues/172 – Savin Apr 06 '20 at 21:30

As mentioned, you can control this at a flow-level using the --max-workers flag.

To permanently override the --max-workers flag for a flow, here is a decorator. It can also be used to override other Metaflow flags, such as --max-num-splits.

import sys
import logging

logger = logging.getLogger(__name__)


def fix_cli_args(**kwargs: str):
    """
    Decorator to override Metaflow CLI arguments.

    Usage:
        @fix_cli_args(**{"--max-workers": "1", "--max-num-splits": "100"})
        class InferencePipeline(FlowSpec): ...

    Warnings:
        If an argument is also passed by the user on the command line, it will be
        overridden by the value specified in the decorator and a warning will be logged.
    """

    def decorator(pipeline):
        def wrapper():
            if "run" not in sys.argv and "resume" not in sys.argv:
                # Ignore this decorator if we are not running or resuming a flow.
                return pipeline()
            for arg, val in kwargs.items():
                if arg in sys.argv:  # if the arg was passed, override its value
                    ind = sys.argv.index(arg)
                    logger.warning(f"`{arg}` was passed with value `{sys.argv[ind + 1]}`. However, "
                                   f"this value will be overridden by @fix_cli_args with value `{val}`")
                    sys.argv[ind + 1] = val  # replace the value
                else:  # otherwise, append the (arg, val) pair to the call
                    sys.argv.extend([arg, val])
            logger.info(f"Fixed CLI args: {list(kwargs)}")
            return pipeline()

        return wrapper

    return decorator
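The heart of the decorator is the argv rewriting. Pulled out as a plain function (a sketch with the logger and Metaflow plumbing omitted; the function name is illustrative), it behaves like this:

```python
def override_args(argv, overrides):
    # Replace the value after each flag that is already present;
    # append (flag, value) pairs that are missing.
    argv = list(argv)
    for arg, val in overrides.items():
        if arg in argv:
            argv[argv.index(arg) + 1] = val
        else:
            argv.extend([arg, val])
    return argv

args = override_args(["myflow.py", "run", "--max-workers", "8"],
                     {"--max-workers": "1", "--max-num-splits": "100"})
print(args)
# ['myflow.py', 'run', '--max-workers', '1', '--max-num-splits', '100']
```

Note that this assumes every flag takes a value in the following argv slot, which holds for --max-workers and --max-num-splits but not for boolean flags.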
crypdick