
I have code using the pyspark library, and I want to test it with pytest.

However, I want to mock the .repartition() method on DataFrames when running tests.

  1. Suppose the code I want to test is a chained pyspark function like the one below:
import pyspark.sql

def transform(df: pyspark.sql.DataFrame):
    # Repartition by id, then sum quantity per id
    return (
        df
        .repartition("id")
        .groupby("id")
        .sum("quantity")
    )
  2. Currently my testing function looks like this:
import pytest

@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df
  3. Now, how can I mock the .repartition() method for my test? Something like this pseudo-code (currently not working):
import pytest
from unittest import mock

@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
@mock.patch("pyspark.sql.DataFrame.repartition")
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df

1 Answer


Chain the mocked calls as shown below (see this similar question for the same pattern):

from unittest import mock
from unittest.mock import Mock

@mock.patch("pyspark.sql.DataFrame")
def test_transform(df: Mock):
    expected_df = "expected value"
    # Stub the whole chain: repartition() -> groupby() -> sum()
    df.repartition.return_value.groupby.return_value.sum.return_value = expected_df

    df_output = transform(df)

    assert df_output == expected_df
    # Verify each link in the chain received the expected argument
    df.repartition.assert_called_with("id")
    df.repartition().groupby.assert_called_with("id")
    df.repartition().groupby().sum.assert_called_with("quantity")
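
If the goal is to stub only .repartition() while the rest of the pipeline runs on real DataFrames (closer to what the question asks), a minimal sketch could patch just that one method as a no-op. This is an assumption about the setup, not part of the answer above: transform is the question's function, and spark is presumed to be a SparkSession test fixture.

from unittest import mock

import pyspark.sql

def test_transform_noop_repartition(spark):  # `spark` fixture is assumed
    df = spark.createDataFrame([(1, 2), (1, 3)], ["id", "quantity"])
    # autospec=True passes `self` through to side_effect, so repartition
    # becomes a no-op returning the DataFrame unchanged, while groupby()
    # and sum() still execute for real
    with mock.patch.object(
        pyspark.sql.DataFrame,
        "repartition",
        autospec=True,
        side_effect=lambda self, *args, **kwargs: self,
    ):
        result = transform(df).collect()
    assert result == [(1, 5)]  # illustrative expected output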

  • Thanks, it works, so I'm giving you the bounty. Unfortunately, my real function involved more steps of dataframe manipulation with complex chained calls, so I eventually just fixed the execution time by setting SparkConf `("spark.sql.shuffle.partitions", 1)` and `("spark.default.parallelism", 1)` – Henry8 Nov 06 '22 at 09:49
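
For reference, the configuration mentioned in that comment could be wired into a session-scoped pytest fixture. A minimal sketch, assuming a local SparkSession (the fixture name spark, master setting, and app name are illustrative):

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    # A single partition for shuffles and parallelism keeps
    # repartition cheap when running tests locally
    session = (
        SparkSession.builder
        .master("local[1]")
        .appName("pytest-spark")
        .config("spark.sql.shuffle.partitions", "1")
        .config("spark.default.parallelism", "1")
        .getOrCreate()
    )
    yield session
    session.stop()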