I have code that uses the pyspark
library and I want to test it with pytest.
However, I want to mock the .repartition()
method on DataFrames when running the tests.
- Suppose the code I want to test is a pyspark chained function like below:
import pyspark.sql

def transform(df: pyspark.sql.DataFrame):
    return (
        df
        .repartition("id")
        .groupby("id")
        .sum("quantity")
    )
- Currently my testing function looks like this:
@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df
- Now, how can I mock the
.repartition()
method for my test? Something like this pseudo-code (currently not working):
from unittest import mock

@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
@mock.patch("pyspark.sql.DataFrame.repartition")
def test_transform(mock_repartition, df, expected_df):  # mock.patch injects the mock as the first argument
    # This still fails: the patched repartition returns a MagicMock,
    # so the chained .groupby() no longer operates on a real DataFrame.
    df_output = transform(df)
    assert df_output == expected_df
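One way this can be made to work is to patch repartition with a pass-through that returns the frame unchanged, so the rest of the chain still runs. The sketch below uses a hypothetical FakeDataFrame stand-in (an assumption, so the example runs without a Spark session); with real pyspark you would patch "pyspark.sql.DataFrame.repartition" the same way via mock.patch(..., new=...).

```python
from unittest import mock

# FakeDataFrame is a hypothetical stand-in for pyspark.sql.DataFrame,
# used here only so the example runs without a Spark session.
class FakeDataFrame:
    def __init__(self, rows):
        self.rows = rows

    def repartition(self, *cols):
        # The expensive shuffle we want to skip in tests.
        raise RuntimeError("real repartition should not run in tests")

    def groupby(self, col):
        return self

    def sum(self, col):
        return self

def transform(df):
    return df.repartition("id").groupby("id").sum("quantity")

# Replace repartition with a pass-through that returns the frame unchanged,
# so the rest of the chain (.groupby().sum()) still works. With real pyspark,
# the equivalent would be:
#   mock.patch("pyspark.sql.DataFrame.repartition", new=lambda self, *cols: self)
with mock.patch.object(FakeDataFrame, "repartition", new=lambda self, *cols: self):
    df = FakeDataFrame(rows=[{"id": 1, "quantity": 2}])
    out = transform(df)

print(out is df)  # → True: the chain skipped the shuffle and returned the same frame
```

Using `new=` avoids the extra injected mock argument in the test signature, and the lambda keeps the chained calls working on the original frame. Outside the `with` block the real method is restored automatically.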