
I would like to create a conditional task in Airflow as described in the schema below. The expected scenario is the following:

  • Task 1 executes
  • If Task 1 succeeds, then execute Task 2a
  • Else, if Task 1 fails, then execute Task 2b
  • Finally, execute Task 3

[schema: Conditional Task]

All tasks above are SSHExecuteOperators. I'm guessing I should be using the ShortCircuitOperator and/or XCom to manage the condition, but I am not clear on how to implement that. Could you please describe the solution?

Alexis.Rolland

3 Answers


Airflow 2.x

Airflow provides a branching decorator that allows you to return the task_id (or list of task_ids) that should run:

from airflow.decorators import task


@task.branch(task_id="branch_task")
def branch_func(ti):
    xcom_value = int(ti.xcom_pull(task_ids="start_task"))
    if xcom_value >= 5:
        return "big_task"  # run just this one task, skip all else
    elif xcom_value >= 3:
        return ["small_task", "warn_task"]  # run these, skip all else
    else:
        return None  # skip everything
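
Wired into a complete DAG, the callable above might look like the following. This is only a minimal sketch, not part of the original answer: the start_task value, the EmptyOperator placeholders, and the DAG settings are assumptions.

import pendulum
from airflow.decorators import dag, task
from airflow.operators.empty import EmptyOperator


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def branch_example():
    @task(task_id="start_task")
    def start_task():
        return 4  # the value the branching decision is based on

    @task.branch(task_id="branch_task")
    def branch_func(ti=None):
        xcom_value = int(ti.xcom_pull(task_ids="start_task"))
        if xcom_value >= 5:
            return "big_task"
        elif xcom_value >= 3:
            return ["small_task", "warn_task"]
        return None  # skip everything downstream

    # placeholder downstream tasks; swap in whatever operators you actually need
    big_task = EmptyOperator(task_id="big_task")
    small_task = EmptyOperator(task_id="small_task")
    warn_task = EmptyOperator(task_id="warn_task")

    start_task() >> branch_func() >> [big_task, small_task, warn_task]


branch_example()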

You can also inherit directly from BaseBranchOperator and override the choose_branch method, but for simple branching logic the decorator is the better choice.
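
For reference, a minimal sketch of that BaseBranchOperator approach, adapted from the pattern in the Airflow docs; the task ids here are placeholders:

from airflow.operators.branch import BaseBranchOperator


class MyBranchOperator(BaseBranchOperator):
    def choose_branch(self, context):
        # run an extra branch on the first day of the month
        if context["data_interval_start"].day == 1:
            return ["daily_task_id", "monthly_task_id"]
        return "daily_task_id"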

Airflow 1.x

Airflow has a BranchPythonOperator that can be used to express the branching dependency more directly.

The docs describe its use:

The BranchPythonOperator is much like the PythonOperator except that it expects a python_callable that returns a task_id. The task_id returned is followed, and all of the other paths are skipped. The task_id returned by the Python function has to be referencing a task directly downstream from the BranchPythonOperator task.

If you want to skip some tasks, keep in mind that you can't have an empty path; if needed, create a dummy task.

Code Example

from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator

# `dag` is assumed to be an existing DAG object


def dummy_test():
    # return the task_id of the branch to follow
    return 'branch_a'


A_task = DummyOperator(task_id='branch_a', dag=dag)
B_task = DummyOperator(task_id='branch_false', dag=dag)

branch_task = BranchPythonOperator(
    task_id='branching',
    python_callable=dummy_test,
    dag=dag,
)

branch_task >> A_task
branch_task >> B_task

If you're on an Airflow version >= 1.10.3, you can also return a list of task ids, which lets you skip multiple downstream paths from a single operator and avoids needing a dummy task before joining. A sketch of that variant follows.
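
A minimal sketch of that multi-branch variant (the DAG definition and task ids are placeholders, not part of the original example):

from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator

dag = DAG('multi_branch_example', start_date=datetime(2019, 1, 1), schedule_interval=None)


def choose_branches():
    # return a list of task_ids: every listed task runs, everything else is skipped
    return ['branch_a', 'branch_c']


branch_task = BranchPythonOperator(
    task_id='branching',
    python_callable=choose_branches,
    dag=dag,
)

A_task = DummyOperator(task_id='branch_a', dag=dag)
B_task = DummyOperator(task_id='branch_b', dag=dag)
C_task = DummyOperator(task_id='branch_c', dag=dag)

branch_task >> [A_task, B_task, C_task]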

villasv
  • do you have more details about "return a list of task ids, allowing you to skip multiple downstream paths in a single Operator:" – mr4kino Mar 08 '19 at 13:45
  • @mr4kino Oops looks like it was postponed until 1.10.3, I was too early on that comment ;-) Will update the answer, thanks. – villasv Mar 08 '19 at 14:07
  • @alltej not sure what you mean, but A_task and B_task can be any operator you want (also branch_x on the multibranch example). `DummyOperator` was just a silly example. It's called `BranchPythonOperator` because it uses a Python function to decide what branch to follow, nothing more. – villasv Mar 31 '20 at 22:36
  • If the branch is using a `KubernetesPodOperator` that extract some files and let us say there are no files to extract, I need to mark that task and the downstream tasks as 'Skipped'. – alltej Mar 31 '20 at 22:49
  • If the skipping condition comes from inside an Operator, I suggest using an XCOM and have a `BranchPythonOperator` decide based on that XCOM value. In particular for the `KubernetesPodOperator`, you might want to use `xcom_push=True` to send that status. – villasv Mar 31 '20 at 23:07

You have to use Airflow trigger rules.

All operators have a trigger_rule argument which defines the rule by which the generated task gets triggered.

The trigger rule possibilities:

ALL_SUCCESS = 'all_success'
ALL_FAILED = 'all_failed'
ALL_DONE = 'all_done'
ONE_SUCCESS = 'one_success'
ONE_FAILED = 'one_failed'
DUMMY = 'dummy'

Here is the idea to solve your problem:

from airflow.operators.ssh_execute_operator import SSHExecuteOperator
from airflow.utils.trigger_rule import TriggerRule
from airflow.contrib.hooks import SSHHook

sshHook = SSHHook(conn_id=<YOUR CONNECTION ID FROM THE UI>)

task_1 = SSHExecuteOperator(
        task_id='task_1',
        bash_command=<YOUR COMMAND>,
        ssh_hook=sshHook,
        dag=dag)

task_2 = SSHExecuteOperator(
        task_id='conditional_task',
        bash_command=<YOUR COMMAND>,
        ssh_hook=sshHook,
        dag=dag)

task_2a = SSHExecuteOperator(
        task_id='task_2a',
        bash_command=<YOUR COMMAND>,
        trigger_rule=TriggerRule.ALL_SUCCESS,
        ssh_hook=sshHook,
        dag=dag)

task_2b = SSHExecuteOperator(
        task_id='task_2b',
        bash_command=<YOUR COMMAND>,
        trigger_rule=TriggerRule.ALL_FAILED,
        ssh_hook=sshHook,
        dag=dag)

task_3 = SSHExecuteOperator(
        task_id='task_3',
        bash_command=<YOUR COMMAND>,
        trigger_rule=TriggerRule.ONE_SUCCESS,
        ssh_hook=sshHook,
        dag=dag)


task_2.set_upstream(task_1)
task_2a.set_upstream(task_2)
task_2b.set_upstream(task_2)
task_3.set_upstream(task_2a)
task_3.set_upstream(task_2b)
Jean S
  • Thank you @Jean S, your solution works like a charm. I have one more question. In a scenario where Task2a is executed and Task2b is skipped, I noticed Task3 is executed at the same time as Task2a, while I would like to execute it after. Would you have a trick for this other than duplicating Task3 in 2 branches (like Task3a and Task3b)? Thanks again. – Alexis.Rolland May 12 '17 at 03:13
  • Hi! Did you try changing trigger_rule=TriggerRule.ONE_SUCCESS to trigger_rule=TriggerRule.ALL_DONE in TASK 3? Are you sure that your tasks are executed at the same time? (try to put a sleep function in T2A as a sanity check) – Jean S May 15 '17 at 13:45
  • From Airflow's documentation here [link](https://airflow.incubator.apache.org/concepts.html#trigger-rules) I confirm that "one_success: fires as soon as **at least one parent** succeeds, **it does not wait for all parents to be done**"... I will try with ALL_DONE! Thank you – Alexis.Rolland May 15 '17 at 17:59
  • Is anyone else trying something like this, but getting a deprecation error? – Reid Aug 11 '17 at 20:51
  • Failure seems a bit too broad. A task could fail for all sorts of reasons (network or DNS issues for example) and then trigger the wrong downstream task. Is there a way to define two or more different types of success with two different downstream options? e.g. file exists do a, file doesn't exist do b? File sensor doesn't seem to be the right answer, because after all the retries, failure could be for other reasons. – Davos Oct 08 '17 at 15:26
  • For anyone else looking for the new trigger rules documentation (Airflow 2.1+), you can find it here: [Trigger Rules](https://airflow.apache.org/docs/apache-airflow/stable/concepts/dags.html#trigger-rules) – yanniskatsaros May 28 '21 at 18:23

Let me add my take on this.

First of all, sorry for the lengthy post, but I wanted to share the complete solution that works for me.

Background

We have a script that pulls data from a very crappy and slow API. It's slow, so we need to be selective about what we do and don't pull from it (1 request/s with more than 750k requests to make). Occasionally the requirements change, forcing us to pull the data in full, but only for one or a few endpoints. So we need something we can control.

The strict rate limit of 1 request/s with several seconds of delay if breached would halt all parallel tasks.

The 'catchup': True option essentially means a backfill, which is translated into a command-line option (-c).

There are no data dependencies between our tasks; we only need to follow the order of (some) tasks.

Solution

Introducing a pre_execute callable, together with the extra DAG run config, takes care of properly skipping tasks by raising AirflowSkipException.

Secondly, based on the config, we can swap the original operator for a simple PythonOperator with the same task_id and a trivial definition. This way the UI won't be confused and the trigger history will be kept complete, showing the runs in which a task was skipped.

from airflow import DAG
from airflow.exceptions import AirflowSkipException
from airflow.operators.python import PythonOperator

from plugins.airflow_utils import default_args, kubernetes_pod_task


# callable for pre_execute arg
def skip_if_specified(context):
    task_id = context['task'].task_id
    conf = context['dag_run'].conf or {}
    skip_tasks = conf.get('skip_task', [])
    if task_id in skip_tasks:
        raise AirflowSkipException()

# these are necessary to make this solution work
support_task_skip_args = {'pre_execute': skip_if_specified,
                          'trigger_rule': 'all_done'}
extended_args = {**default_args, **support_task_skip_args}

dag_name = 'optional_task_skip'

dag = DAG(dag_name,
          max_active_runs=3,
          schedule_interval=None,
          catchup=False,
          default_args=extended_args)

# select endpoints and modes
# !! make sure the dict items are in the same order as the order you want them to run !!
task_options = {
    'option_name_1':
        {'param': 'fetch-users', 'enabled': True, 'catchup': False},
    'option_name_2':
        {'param': 'fetch-jobs', 'enabled': True},
    'option_name_3':
        {'param': 'fetch-schedules', 'enabled': True, 'catchup': True},
    'option_name_4':
        {'param': 'fetch-messages', 'enabled': True, 'catchup': False},
    'option_name_5':
        {'param': 'fetch-holidays', 'enabled': True, 'catchup': False},
}


def add_tasks():
    task_list_ = []
    for task_name_, task_config_ in task_options.items():
        if task_config_['enabled']:
            parameter_ = task_config_['param']
            catchup_ = '-c ' if task_config_.get('catchup') else ''
            task_list_.append(
                kubernetes_pod_task(
                    dag=dag,
                    command=f"cd people_data; python3 get_people_data.py {parameter_} {catchup_}",
                    task_id=f"{task_name_}"))
            if len(task_list_) > 1:
                task_list_[-2] >> task_list_[-1]
        else:
            # the callable that throws the skip signal
            def skip_task(): raise AirflowSkipException()

            task_list_.append(
                PythonOperator(dag=dag,
                               python_callable=skip_task,
                               task_id=f"{task_name_}",
                               )
            )
            if len(task_list_) > 1:
                task_list_[-2] >> task_list_[-1]


# populate the DAG
add_tasks()

Note: default_args and kubernetes_pod_task are just convenience wrappers. The kubernetes pod task injects some variables and secrets in a simple function and uses the airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator module; I won't and can't share those with you, but a hypothetical stand-in is sketched below.
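
If you want to run the DAG above, a stand-in for those wrappers could look roughly like this (owner, image, and namespace are placeholders, not the real implementation):

from datetime import datetime

from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator

# hypothetical stand-ins for the convenience wrappers imported above
default_args = {
    'owner': 'airflow',                 # placeholder
    'start_date': datetime(2021, 1, 1),
    'retries': 1,
}


def kubernetes_pod_task(dag, command, task_id):
    # every pod task shares the same (placeholder) image and namespace
    return KubernetesPodOperator(
        dag=dag,
        task_id=task_id,
        name=task_id,
        namespace='data-pipelines',              # placeholder
        image='my-registry/people-data:latest',  # placeholder
        cmds=['bash', '-c'],
        arguments=[command],
        get_logs=True,
    )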

The solution extends the great ideas of this gentleman:
https://www.youtube.com/watch?v=abLGyapcbw0

This solution works with Kubernetes operators, too.

Of course, this could be improved, and you absolutely can extend or rework the code to parse manual trigger config as well (as it is shown in the video).

Here's what it looks like in my UI: [screenshot of the DAG runs]

(it doesn't reflect the example config above but rather the actual runs in our staging infrastructure)

Gergely M