
I have a DAG where several tasks that do the same work are created in a for loop. After each of these tasks, I want to branch into one of two downstream tasks depending on that task's result. However, every branch task pulls the XCom of the last task created in the loop. How can I make each branch task pull the XCom of its own upstream task?

Each of the tasks a, b, and c returns its own XCom (xcom_a, xcom_b, xcom_c), but the branch tasks all pull the same xcom_c. What should I do?

# Imports assumed by this DAG (Airflow 1.10-style paths; in Airflow 2
# the operators live in airflow.operators.python).
from airflow import DAG
from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
from airflow.utils.dates import days_ago

default_args = {'start_date': days_ago(1)}
dag = DAG(
    dag_id='batch_test',
    default_args=default_args,
    schedule_interval=None)

def count(**context):
    name = context['params']['name']
    dict = {'a':50,
        'b':100,
        'c':150}
    if dict[name]<100:
        task_id=f'add_{name}'
        return task_id
    elif dict[name]>=100:
        task_id=f'times_{name}'
        return task_id
def branch(**context):
    task_id = context['ti'].xcom_pull(task_ids=f'count_task_{name}')
    return task_id
def add(**context):
    ans = context['ti'].xcom_pull(task_ids=f'branch_task_{name}')
    ans_dict = {'add_a':50+100,
                'add_b':100+100,
                'add_c':150+100}
    ans = ans_dict[ans]
    return print(ans)
def times(**context):
    ans = context['ti'].xcom_pull(task_ids=f'branch_task_{name}')
    ans_dict = {'times_a':50*100,
            'times_b':100*100,
            'times_c':150*100}
    ans = ans_dict[ans]
    return print(ans)

name_list = ['a','b','c']
for name in name_list:
    exec_count_task = PythonOperator(
        task_id = f'count_task_{name}',
        python_callable = count,
        provide_context = True,
        params = {'name':name},
        dag = dag
        )
    exec_branch_task = BranchPythonOperator(
        task_id = f'branch_task_{name}',
        python_callable = branch,
        provide_context = True,
        dag = dag
        )
    exec_add_count = PythonOperator(
        task_id = f'add_{name}',
        python_callable = add,
        provide_context = True,
        dag = dag
        )
    exec_times_count = PythonOperator(
        task_id = f'times_{name}',
        python_callable = times,
        provide_context = True,
        dag = dag
        )

    exec_count_task >> exec_branch_task >> [exec_add_count, exec_times_count]

I want this:
task_a >> branch_a (branch python operator, xcom pull returned by task_a) >> [task_a1, task_a2]
task_b >> branch_b (branch python operator, xcom pull returned by task_b) >> [task_b1, task_b2]
task_c >> branch_c (branch python operator, xcom pull returned by task_c) >> [task_c1, task_c2]

but I get this:
task_a >> branch_a (branch python operator, xcom pull returned by task_c) >> [task_a1, task_a2]
task_b >> branch_b (branch python operator, xcom pull returned by task_c) >> [task_b1, task_b2]
task_c >> branch_c (branch python operator, xcom pull returned by task_c) >> [task_c1, task_c2]

DG A

2 Answers


Your functions branch, add, and times don't define name themselves, so it is looked up in the global scope at the time each function runs, and by then it holds the last value of for name in name_list. This is a common trap, explained e.g. here: tkinter creating buttons in for loop passing command arguments
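
Here is a minimal plain-Python sketch of the same effect, outside Airflow, just to illustrate the late-binding behavior (hypothetical example, not taken from your DAG):

funcs = []
for name in ['a', 'b', 'c']:
    def show():
        # 'name' is looked up when show() runs, not when it is defined
        return name
    funcs.append(show)

print([f() for f in funcs])  # ['c', 'c', 'c'] -- every function sees the last value

# Binding the current value as a default argument captures it per iteration
funcs = [lambda name=name: name for name in ['a', 'b', 'c']]
print([f() for f in funcs])  # ['a', 'b', 'c']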

To fix it, you can either pull name from context as in count, or provide it via op_args or op_kwargs when you create the respective operator, as in the answer by Josh Fell:

        first = PythonOperator(task_id=f"first_task_{i}", python_callable=xcom_push, op_kwargs={"val": i})
        branch = BranchPythonOperator(task_id=f"branch_{i}", python_callable=choose, op_kwargs={"val": i})
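
Applied to the DAG in the question, a minimal sketch of the op_kwargs approach could look like the following (only branch and its operator are shown; add and times would need the same treatment, and this is untested):

def branch(name, **context):
    # Pull the task_id returned by this stream's own count task
    return context['ti'].xcom_pull(task_ids=f'count_task_{name}')

for name in name_list:
    exec_branch_task = BranchPythonOperator(
        task_id = f'branch_task_{name}',
        python_callable = branch,
        op_kwargs = {'name': name},   # bind the current loop value explicitly
        provide_context = True,
        dag = dag
        )
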
  • Indeed, this is a subtle trap. Another way is to pass name as an argument to branch, add, and times, so they don't take it from the global context. – Thomas Mar 10 '23 at 15:52

I'm unable to reproduce the behavior you describe using either classic operators or the TaskFlow API. If you can add more context and the code you are actually executing, that would be most helpful.

In the meantime, here are the examples I used, in case they give you some guidance for troubleshooting. I added a task at the end of each stream to check that the first task indeed pushes its expected value.

Classic Operators

from pendulum import datetime

from airflow.models import DAG
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.utils.trigger_rule import TriggerRule


with DAG(dag_id="multiple_branch_loop", start_date=datetime(2023, 1, 1), schedule=None):
    def xcom_push(val):
        return val

    def func():
        ...

    def choose(val):
        return f"task_{val}"

    def check_xcom_output_from_first(val, expected_val):
        assert val == expected_val

    stuff = ["a", "b", "c"]
    for i in stuff:
        first = PythonOperator(task_id=f"first_task_{i}", python_callable=xcom_push, op_kwargs={"val": i})
        branch = BranchPythonOperator(task_id=f"branch_{i}", python_callable=choose, op_kwargs={"val": i})
        second = PythonOperator(task_id=f"task_{i}", python_callable=func)
        third = PythonOperator(task_id=f"task_{i}a", python_callable=func)
        check = PythonOperator(
            task_id=f"check_{i}",
            trigger_rule=TriggerRule.ALL_DONE,
            python_callable=check_xcom_output_from_first,
            op_kwargs={"val": first.output, "expected_val": i},
        )

        first >> branch >> [second, third] >> check

The check* tasks succeed, meaning the first task in a given stream pushes its own value and not the last stream's.

TaskFlow API

from pendulum import datetime

from airflow.decorators import dag, task
from airflow.utils.trigger_rule import TriggerRule


@dag(start_date=datetime(2023, 1, 1), schedule=None)
def multiple_branch_loop():
    @task()
    def xcom_push(val):
        return val

    @task()
    def func():
        ...

    @task.branch()
    def choose(val):
        return f"task_{val}"

    @task(trigger_rule=TriggerRule.ALL_DONE)
    def check_xcom_output_from_first(val, expected_val):
        assert val == expected_val

    stuff = ["a", "b", "c"]
    for i in stuff:
        first = xcom_push.override(task_id=f"first_task_{i}")(val=i)
        branch = choose.override(task_id=f"branch_{i}")(val=first)
        second = func.override(task_id=f"task_{i}")()
        third = func.override(task_id=f"task_{i}a")()
        check = check_xcom_output_from_first.override(task_id=f"check_{i}")(val=first, expected_val=i)

        first >> branch >> [second, third] >> check

multiple_branch_loop()

Same expected behavior here as well, confirmed by the check* tasks.

Josh Fell
  • Thank you very much for your reply. However, the same problem still occurs. Each of task_a, task_b, and task_c returns xcom_a, xcom_b, and xcom_c, but the branch tasks all get xcom_c... – DG A Feb 12 '23 at 10:56