
I am trying to write unit tests for some tasks built with the Airflow TaskFlow API. I have tried multiple approaches, for example creating a DagRun or just calling the task function, but none of them work.

Here is a task that downloads a file from S3. The real task does more than this, but I removed the rest for this example:

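import logging

from airflow.decorators import task
from airflow.operators.python import get_current_context

# `utils` is a project helper module imported by this file (import not shown in the question)
log = logging.getLogger(__name__)  # assumed; the logger setup is not shown

@task()
def updates_process(files):
    context = get_current_context()
    try:
        updates_file_path = utils.download_file_from_s3_bucket(files.get("updates_file"))
    except FileNotFoundError as e:
        # A missing file is logged and the task ends without failing
        log.error(e)
        return

    # Do something else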

Now I am trying to write a test case that exercises this except clause. Here is the first version I started with:

from unittest import TestCase, mock

from dags.delta_load.updates import updates_process


class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = updates_process({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()

I also tried creating a DagRun, as shown in the example in the docs, and fetching the task from it, but that didn't help either.

Sadan A.


I was struggling with this myself, but I found that decorated tasks have a `.function` attribute that holds the original Python function.

You can call `.function(...)` to run that function directly. Using your example:

from unittest import TestCase, mock

import dags.delta_load.updates


class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = dags.delta_load.updates.updates_process
        # Call the undecorated function directly, bypassing the operator machinery
        task.function({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()

This saves you from setting up any of the DAG infrastructure; it simply runs the Python function as intended.
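
To see why the test in the question never reaches the except branch, here is a toy model of what a TaskFlow decorator does. It is a hypothetical stand-in (`toy_task` and `_ToyTask` are made-up names), not Airflow's real implementation: calling the decorated object only records that a task was created and returns a placeholder (in Airflow this is an XComArg), while `.function` runs the original body.

```python
class _ToyTask:
    """Hypothetical stand-in for the object a TaskFlow @task decorator returns.

    Calling it does not run the body: it records a "task creation" and returns
    a placeholder. The original callable stays reachable as `.function`.
    """

    def __init__(self, function):
        self.function = function
        self.created = []  # arguments received each time a "task" was created

    def __call__(self, *args, **kwargs):
        # Only build a placeholder; the wrapped body is never executed here
        self.created.append((args, kwargs))
        return f"<placeholder for {self.function.__name__}>"


def toy_task(function):
    """Hypothetical decorator: wraps `function` in a _ToyTask."""
    return _ToyTask(function)


@toy_task
def updates_process(files):
    # Stand-in body: the real task would download the file from S3 here
    return files.get("updates_file")


files = {"updates_file": "path/to/file.csv"}
print(updates_process(files))           # <placeholder for updates_process>
print(updates_process.function(files))  # path/to/file.csv
```

In the question's test, the mocks are never hit because only the first kind of call happens.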

fbardos
AetherUnbound
  • Thanks, but it didn't work for me. I got AttributeError: 'function' object has no attribute 'function'. After some investigation I found I can use `task.__wrapped__()` in the way you described. – Sadan A. Jul 27 '22 at 10:25
  • `.function` worked for me too. So in your example it would be `assert updates_process.function(["file1", "file2", ...])` – A H Sep 01 '22 at 03:19
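
The two comments suggest that which attribute exists depends on the Airflow version. Assumption: newer Airflow 2.x releases expose `.function`, while earlier ones built the wrapper with `functools.wraps`, which only sets `__wrapped__`. A hedged sketch of a helper (the name `get_task_callable` is hypothetical) that tries `.function` first and falls back to `__wrapped__`; the two decorator styles below are toy stand-ins, not Airflow code:

```python
import functools
from typing import Any, Callable


def get_task_callable(decorated: Any) -> Callable:
    """Return the plain Python callable behind a decorated task (hypothetical helper)."""
    if hasattr(decorated, "function"):
        return decorated.function  # object-style decorator (assumed newer Airflow)
    if hasattr(decorated, "__wrapped__"):
        return decorated.__wrapped__  # set by functools.wraps (assumed older Airflow)
    raise AttributeError(f"{decorated!r} does not expose an underlying callable")


# Toy stand-in for a functools.wraps-based decorator
def old_style_task(f):
    @functools.wraps(f)  # copies metadata and sets factory.__wrapped__ = f
    def factory(*args, **kwargs):
        return "<placeholder>"
    return factory


# Toy stand-in for an object-based decorator that keeps `.function`
class _NewStyleTask:
    def __init__(self, function):
        self.function = function

    def __call__(self, *args, **kwargs):
        return "<placeholder>"


def body(files):
    return files.get("updates_file")


old = old_style_task(body)
new = _NewStyleTask(body)
print(get_task_callable(old)({"updates_file": "file.csv"}))  # file.csv
print(get_task_callable(new)({"updates_file": "file.csv"}))  # file.csv
```

Checking `.function` first keeps the helper working even if a decorator object happens to expose both attributes.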

This is what I could figure out. I'm not sure it is the right approach, but it works:

from unittest import TestCase, mock

# dag_delta_load is the project's @dag-decorated factory; its import path is not shown here.


class TestAccountLinkUpdatesProcess(TestCase):
    TASK_ID = "updates_process"

    @classmethod
    def setUpClass(cls) -> None:
        # Build the DAG once; get_task() returns the operator created by @task
        cls.dag = dag_delta_load()

    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = self.dag.get_task(task_id=self.TASK_ID)
        # Give the operator a concrete positional argument, then run it
        task.op_args = [{"updates_file": "file.csv"}]
        task.execute(context={})
        log.error.assert_called_once()

UPDATE: Based on @AetherUnbound's answer, I did some more investigation and found that updates_process.__wrapped__() calls the underlying Python function directly:

from unittest import TestCase, mock

from dags.delta_load.updates import updates_process


class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        # __wrapped__ is the original function that the decorator wrapped
        updates_process.__wrapped__({"updates_file": "file.csv"})
        log.error.assert_called_once()
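
A self-contained sketch of the same `__wrapped__` test that runs without Airflow or S3. Everything except the task body is a hypothetical stand-in: the `task` decorator, `get_current_context`, `utils` and `log`. `mock.patch.dict` on the function's `__globals__` replaces `get_current_context` here only so the sketch works as a single file; in a real project the string targets used above, such as "dags.delta_load.updates.get_current_context", are the usual approach.

```python
import functools
import logging
import types
import unittest
from unittest import mock

# Hypothetical stand-ins so this file runs without Airflow or S3
log = logging.getLogger("updates")
utils = types.SimpleNamespace(download_file_from_s3_bucket=lambda key: f"/tmp/{key}")


def get_current_context():
    # Assumption: like Airflow, fail when no task is running
    raise RuntimeError("no task context: not running inside a task")


def task(f):
    # Toy decorator: functools.wraps exposes f as factory.__wrapped__
    @functools.wraps(f)
    def factory(*args, **kwargs):
        return "<placeholder>"  # a real decorator would build an operator here
    return factory


# Task under test, mirroring the question
@task
def updates_process(files):
    context = get_current_context()  # unused, kept to mirror the question
    try:
        updates_file_path = utils.download_file_from_s3_bucket(files.get("updates_file"))
    except FileNotFoundError as e:
        log.error(e)
        return None
    return updates_file_path


class TestUpdatesProcess(unittest.TestCase):
    def test_file_not_found_error(self):
        fn = updates_process.__wrapped__  # the undecorated body
        fake_context = mock.Mock(return_value={})
        # Temporarily swap get_current_context in the function's module globals
        with mock.patch.dict(fn.__globals__, {"get_current_context": fake_context}):
            with mock.patch.object(
                utils, "download_file_from_s3_bucket", side_effect=FileNotFoundError("missing")
            ):
                with mock.patch.object(log, "error") as log_error:
                    result = fn({"updates_file": "file.csv"})
        self.assertIsNone(result)
        fake_context.assert_called_once()
        log_error.assert_called_once()


if __name__ == "__main__":
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestUpdatesProcess)
    unittest.TextTestRunner(verbosity=2).run(suite)
```

Without the get_current_context patch, the raw body would fail before reaching the S3 call, which is why the answers above patch it as well.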
Sadan A.