
I have a DAG that uses KubernetesPodOperator, and the task get_train_test_model_task_count in the DAG pushes an XCom value that I want to use in the following tasks.

run_this = BashOperator(
    task_id="also_run_this",
    bash_command='echo "ti_key={{ ti.xcom_pull(task_ids=\"get_train_test_model_task_count\", key=\"return_value\")[\"models_count\"] }}"',
)

The above task works and prints the value as ti_key=24.

I want to use the same value as a variable:

with TaskGroup("train_test_model_config") as train_test_model_config:
    models_count = "{{ ti.xcom_pull(task_ids=\"get_train_test_model_task_count\", key=\"return_value\")[\"models_count\"] }}"
    print(models_count)
    for task_num in range(0, int(models_count)):
        generate_train_test_model_config_task(task_num)

int(models_count) does not work; it throws this error:

ValueError: invalid literal for int() with base 10: '{{ ti.xcom_pull(task_ids="get_train_test_model_task_count", key="return_value")["models_count"] }}'

And generate_train_test_model_config_task looks like this:

def generate_train_test_model_config_task(task_num):
    task = KubernetesPodOperator(
        name=f"train_test_model_config_{task_num}",
        image=build_model_image,
        labels=labels,
        cmds=[
            "python3",
            "-m",
            "src.models.train_test_model_config",
            "--tenant=neu",
            f"--model_tag_id={task_num}",
            "--line_plan={{ ti.xcom_pull(key=\"file_name\", task_ids=\"extract_file_name\") }}",
            "--staging_bucket=cs-us-ds"
        ],
        task_id=f"train_test_model_config_{task_num}",
        do_xcom_push=False,
        namespace="airflow",
        service_account_name="airflow-worker",
        get_logs=True,
        startup_timeout_seconds=300,
        container_resources={"request_memory": "29G", "request_cpu": "7000m"},
        node_selector={"cloud.google.com/gke-nodepool": NODE_POOL},
        tolerations=[
            {
                "key": NODE_POOL,
                "operator": "Equal",
                "value": "true",
                "effect": "NoSchedule",
            }
        ],
    )

    return task
Tom J Muthirenthi
  • Why are you using a TaskGroup? Can you not use a PythonOperator for this case? – Lucas M. Uriarte Aug 29 '23 at 08:53
  • Because the loop is created when Airflow parses the DAG (every 30 sec by default) and Jinja is rendered only at runtime (when the DAG is actually triggered), so at parse time it's a plain str. You need to work with dynamic task mapping – ozs Aug 29 '23 at 09:26
  • But it works when I replace `int(models_count)` with any number – Tom J Muthirenthi Aug 29 '23 at 09:28

1 Answer


The Jinja template pulls from the Airflow context, which you can only do within a task, not in top-level DAG code.
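To illustrate (a minimal, untested sketch; the task name use_models_count is just a placeholder): the value can only be rendered by pulling it inside a task body, where the runtime context exists.

from airflow.decorators import task

@task
def use_models_count(**context):
    # the XCom is only available once the task runs and has a task instance (ti)
    models_count = context["ti"].xcom_pull(
        task_ids="get_train_test_model_task_count", key="return_value"
    )["models_count"]
    print(models_count)  # prints 24 at runtime, not the raw "{{ ... }}" string
    return models_count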

Also, as a commenter said, you will need to use dynamic task mapping to change the DAG structure dynamically. Even if you hardcode the model_num or template it in some other way, those code changes are only picked up by the scheduler every 30 seconds by default, and you have no backwards visibility into previous tasks: for example, if one day there are only 2 models, you can't see models 3 through 8 in the logs from the day before. It gets messy when using a loop like that, even if you can get it to work.

The code below shows the structure that I think will achieve what you want: one model config generated for each task_num. This should work in Airflow 2.3+.

from airflow.decorators import task


@task
def generate_list_of_model_nums(**context):
    # pull the count that get_train_test_model_task_count pushed to XCom
    model_count = context["ti"].xcom_pull(
        task_ids="get_train_test_model_task_count", key="return_value"
    )["models_count"]
    # one entry per model: 0 .. model_count - 1, matching range(0, models_count)
    return list(range(model_count))


@task
def generate_train_test_model_config_task(task_num):
    # code that generates the model config
    return model_config


model_nums = generate_list_of_model_nums()
generate_train_test_model_config_task.expand(task_num=model_nums)

Note: I did not test the code above, so there might be typos, but this is the general idea: create a list of all the task nums, then use dynamic task mapping to expand over the list.

If you pull the XCom from generate_train_test_model_config_task you should get a list of all the model configs :)
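For example (an untested sketch; collect_model_configs is just an illustrative downstream task and has to run after the mapped task), pulling by the mapped task's task_id returns one XCom per mapped instance:

@task
def collect_model_configs(**context):
    # pulling a mapped task's XComs returns a list-like object, one entry per mapped instance
    configs = context["ti"].xcom_pull(task_ids="generate_train_test_model_config_task")
    return list(configs)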

Some resources that might help to adapt this to traditional operators:

Disclaimer: I work at Astronomer, the org that created the guides above :)

EDIT: Thanks for sharing the KPO code! I see you are using the task_num in two parameters, which means you can try to use .expand_kwargs over a list of input dictionaries and map the KPO directly. Note that this is an Airflow 2.4+ feature.

Note on the code: I tested the dict generation function, but I don't have a K8s cluster running right now, so I did not test the latter part; I think name and cmds should be expandable.

@task
def generate_list_of_param_dicts(**context):
    model_count = context["ti"].xcom_pull(
        task_ids="get_train_test_model_task_count", key="return_value"
    )["models_count"]

    param_dicts = []
    for i in range(model_count):
        param_dict = {
            "name": f"train_test_model_config_{i}",
            "cmds": [
                "python3",
                "-m",
                "src.models.train_test_model_config",
                "--tenant=neu",
                f"--model_tag_id={i}",
                '--line_plan={{ ti.xcom_pull(key="file_name", task_ids="extract_file_name") }}',
                "--staging_bucket=cs-us-ds",
            ],
        }
        param_dicts.append(param_dict)

    return param_dicts


# give the mapped task its own variable name so it does not shadow the @task decorator
train_test_model_config = KubernetesPodOperator.partial(
    image=build_model_image,
    labels=labels,
    task_id="train_test_model_config",
    do_xcom_push=False,
    namespace="airflow",
    service_account_name="airflow-worker",
    get_logs=True,
    startup_timeout_seconds=300,
    container_resources={"request_memory": "29G", "request_cpu": "7000m"},
    node_selector={"cloud.google.com/gke-nodepool": NODE_POOL},
    tolerations=[
        {
            "key": NODE_POOL,
            "operator": "Equal",
            "value": "true",
            "effect": "NoSchedule",
        }
    ],
).expand_kwargs(generate_list_of_param_dicts())
TJaniF
  • `generate_train_test_model_config_task` is used to create a KubernetesPodOperator task with the given task_num as an argument. When used this way, `generate_train_test_model_config_task.expand(task_num=generate_list_of_model_nums())` fails with the error `Object of type KubernetesPodOperator is not JSON serializable.` – Tom J Muthirenthi Aug 30 '23 at 11:29
  • I see. Hmm... can you share how you currently define your KubernetesPodOperator (KPO) with the task_num? I think the solution is likely to map the KPO over the input list, but I'd need to know which of the KPO parameters takes task_num to know if it is possible :) Something like `KubernetesPodOperator.partial(... all your other parameters ...).expand(name=generate_list_of_model_nums())`, but you may need to adjust what generate_list_of_model_nums returns so the input format fits your parameter – TJaniF Aug 30 '23 at 14:28
  • Updated the KPO in the question – Tom J Muthirenthi Aug 30 '23 at 15:44
  • Thank you! Added a suggestion in my answer :) – TJaniF Aug 30 '23 at 16:27