I have implemented dynamic task group mapping with a Python operator and a deferrable operator inside the task group.
I am stuck on controlling how the mapped instances are run at runtime: when the deferrable operator enters the deferred state, Airflow triggers the tasks inside the task group for the next mapped instance value. Instead, I want the next mapped instance to be triggered only after my deferrable operator has completed. Any help with this issue is highly appreciated.
import logging
from airflow.decorators import dag, task_group, task
from airflow.utils.weight_rule import WeightRule
from pendulum import datetime
from airflow.operators.python import get_current_context, PythonOperator
from airflow.utils.context import context_merge
from airflow.triggers.temporal import TimeDeltaTrigger
from datetime import timedelta
from typing import Any, Mapping
from airflow.operators.empty import EmptyOperator
class DeferrableOperator(PythonOperator):
    """PythonOperator variant that defers on a timer before running its callable.

    ``execute`` resolves the callable's keyword arguments and then defers on a
    ``TimeDeltaTrigger`` whose delta is ``poke_interval`` seconds, resuming in
    ``execute_complete``. ``execute_complete`` runs ``python_callable``; a
    falsy result sends the task back through ``execute`` (deferring again),
    any other result is returned.
    """

    def __init__(self, op_kwargs: Mapping[str, Any] | None = None, poke_interval=30, **kwargs):
        """Store the callable's keyword arguments and the deferral interval.

        :param op_kwargs: keyword arguments for ``python_callable``; ``None``
            is treated as an empty mapping.
        :param poke_interval: seconds to wait on the trigger for each deferral.
        :param kwargs: forwarded unchanged to ``PythonOperator.__init__``.
        """
        super().__init__(**kwargs)
        self.op_kwargs = op_kwargs or {}
        self.poke_interval = poke_interval

    def execute(self, context):
        """Resolve the callable's kwargs, log them, and defer the task.

        :param context: the task execution context.
        """
        # Merge op_kwargs into the context and resolve the kwargs, mirroring
        # what the parent class does before it calls the callable.
        context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
        self.op_kwargs = self.determine_kwargs(context)
        # Use the operator's own logger with lazy %-formatting.
        self.log.info("Arguments %s", self.op_kwargs)
        self.defer(
            trigger=TimeDeltaTrigger(delta=timedelta(seconds=self.poke_interval)),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        """Run the callable once the trigger has fired.

        :param context: the task execution context.
        :param event: payload delivered by the trigger; not used here.
        :return: the callable's result when it is truthy.
        """
        context.update(self.op_kwargs)
        result = self.python_callable(**self.op_kwargs)
        if not result:
            # Any falsy result (None, 0, "", ...) is treated as "not done yet":
            # go through execute() again, which defers for another interval.
            self.execute(context)
        # Hand the result back instead of discarding it so it is available as
        # this task's return value to downstream consumers.
        return result
def add(num):
    """Return *num* converted with ``int()`` and increased by ten.

    :param num: any value accepted by ``int()`` (e.g. an int, a float, or a
        numeric string).
    :return: ``int(num) + 10``.
    """
    return int(num) + 10
@dag(
    start_date=datetime(2022, 12, 1),
    schedule=None,
    catchup=False,
    max_active_tasks=1,
)
def task_group_mapping_example3():
    """Map task group ``group1`` over a list of batch numbers.

    ``push_xcom`` produces the batch list, ``tg1`` is expanded once per batch
    value (print -> deferrable add -> end marker), and ``pull_xcom`` prints
    the XComs pulled from the mapped ``add`` task.

    NOTE(review): ``max_active_tasks=1`` presumably does not count task
    instances that are in the deferred state, which would explain the next
    mapped instance starting while ``add`` is deferred - confirm; a one-slot
    pool configured to include deferred tasks may be what is needed.
    """

    @task
    def push_xcom():
        """Push the batch list under the key "batches" and also return it."""
        batches = [19, 23, 42]
        context = get_current_context()
        context["ti"].xcom_push(key="batches", value=batches)
        return batches

    @task_group(group_id="group1")
    def tg1(my_num):
        """Per-batch pipeline: echo the number, run the deferrable add, mark the end."""

        @task(weight_rule=WeightRule.ABSOLUTE)
        def print_num(num):
            """Return the mapped value unchanged."""
            return num

        # Use a distinct name so the task function is not shadowed by its result.
        printed = print_num(my_num)
        add_num = DeferrableOperator(
            task_id="add",
            poke_interval=30,
            op_kwargs={"num": printed},
            python_callable=add,
            weight_rule=WeightRule.ABSOLUTE,
        )
        end = EmptyOperator(
            task_id="batch_execution_completed",
            weight_rule=WeightRule.ABSOLUTE,
            priority_weight=1,
        )
        printed >> add_num >> end

    # a downstream task to print out resulting XComs
    @task
    def pull_xcom(**context):
        """Print the XComs pulled from the mapped ``group1.add`` task."""
        pulled_xcom = context["ti"].xcom_pull(
            # reference a task in a task group with task_group_id.task_id;
            # the operator above is registered as "add", so the id is "group1.add"
            task_ids=["group1.add"],
            # only pull XComs from specific mapped task group instances;
            # three values are mapped, so the valid map indexes are 0, 1 and 2
            map_indexes=[0, 1, 2],
            key="return_value",
        )
        # prints whatever was pulled for the requested map indexes of "add"
        print(pulled_xcom)

    tg1.expand(my_num=push_xcom()) >> pull_xcom()
# Call the @dag-decorated function at module level (runs when the file is imported).
task_group_mapping_example3()