
I'm using Hydra for training machine learning models. It's great for composing complex commands like `python train.py data=MNIST batch_size=64 loss=l2`. However, if I then want to run the trained model with the same parameters, I have to do something like `python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml`. I use argparse to load the previous YAML and the compose API to initialize the Hydra environment. The path to the trained model is inferred from the path to Hydra's .yaml file. If I want to modify one of the parameters, I have to add extra argparse arguments and run something like `python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml --batch_size 128`. The code then manually overrides any Hydra parameters with those specified on the command line.

What's the right way of doing this?

My current code looks something like the following:

train.py:

import hydra

@hydra.main(config_path="conf", config_name="config")
def main(cfg):
    # [training code using cfg.data, cfg.batch_size, cfg.loss etc.]
    # [code outputs model checkpoint to job folder generated by Hydra]
    ...

if __name__ == '__main__':
    main()

reconstruct.py:

import argparse
import os
from hydra.experimental import initialize, compose

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('hydra_config')
    parser.add_argument('--batch_size', type=int)
    # [other flags and parameters I may need to override]
    args = parser.parse_args()

    # Create the Hydra environment.
    initialize()
    cfg = compose(config_name=args.hydra_config)

    # Since checkpoints are stored next to the .hydra directory, we manually
    # recover the job directory from the config file path.
    checkpoint_dir = os.path.dirname(os.path.dirname(args.hydra_config))

    # Manually override any parameters which can be changed on the command line.
    batch_size = args.batch_size if args.batch_size is not None else cfg.batch_size

    # [code which uses checkpoint_dir to load the model]
    # [code which uses both batch_size and params in cfg to set up the data etc.]

This is my first time posting, so let me know if I should clarify anything.

Rachel

1 Answer


If you want to load the previous config as is and not change it, use `OmegaConf.load(file_path)`.
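
For example, a minimal sketch (the output path here is hypothetical; substitute your own job directory):

from omegaconf import OmegaConf

# Load the config exactly as it was saved by the previous job.
cfg = OmegaConf.load("outputs/2021-04-19/21-23-13/.hydra/config.yaml")
print(cfg.batch_size)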

If you want to re-compose the config (and it sounds like you do, because you want to override things), I recommend that you use the Compose API and pass in the parameters from the overrides file in the job output directory (stored next to config.yaml), concatenated with the current run's overrides.
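
For context, the `.hydra/overrides.yaml` stored by the original run is just a YAML list of that run's command-line overrides. For your training command above it would look something like this (contents illustrative):

- data=MNIST
- batch_size=64
- loss=l2

Because both the stored overrides and the current run's overrides are plain lists of key=value strings, they can be concatenated directly, with the current run's entries last. For plain value overrides the later entry should take precedence, so re-specifying e.g. `batch_size` in the new run replaces the stored value.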

This script seems to be doing the job:

import os
from dataclasses import dataclass
from os.path import join
from typing import Optional

from omegaconf import OmegaConf

import hydra
from hydra import compose
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from hydra.utils import to_absolute_path

# You can also use a yaml config file instead of this Structured Config
@dataclass
class Config:
    load_checkpoint: Optional[str] = None
    batch_size: int = 16
    loss: str = "l2"


cs = ConfigStore.instance()
cs.store(name="config", node=Config)


@hydra.main(config_path=".", config_name="config")
def my_app(cfg: Config) -> None:

    if cfg.load_checkpoint is not None:
        output_dir = to_absolute_path(cfg.load_checkpoint)
        original_overrides = OmegaConf.load(join(output_dir, ".hydra/overrides.yaml"))
        current_overrides = HydraConfig.get().overrides.task

        hydra_config = OmegaConf.load(join(output_dir, ".hydra/hydra.yaml"))
        # getting the config name from the previous job.
        config_name = hydra_config.hydra.job.config_name
        # concatenating the original overrides with the current overrides
        overrides = original_overrides + current_overrides
        # compose a new config from scratch
        cfg = compose(config_name, overrides=overrides)

    # train
    print("Running in ", os.getcwd())
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    my_app()
Sample runs:

~/tmp$ python train.py
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-13
load_checkpoint: null
batch_size: 16
loss: l2

~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-22
load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
batch_size: 16
loss: l2

~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13 batch_size=32
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-28
load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
batch_size: 32
loss: l2
Omry Yadan
    This approach works great! I wish it weren't quite so complex, but I think I can just wrap it in a helper function and bundle my new `config_utils.py` with the config files themselves. Thanks for the help! – Rachel Apr 23 '21 at 17:59
  • There could be many ways to load the stored config. This one is very flexible but a bit complicated; for many people, just using `cfg = OmegaConf.load(cfg.config_path)` is enough. I hope this will help others. – Omry Yadan Apr 23 '21 at 22:47