
My current Java/Spark unit test approach works (detailed here) by instantiating a SparkContext with the "local" master and running unit tests with JUnit.

The code has to be organized to do I/O in one function and then call another with multiple RDDs.

This works great. I have a highly tested data transformation written in Java + Spark.

Can I do the same with Python?

How would I run Spark unit tests with Python?

pettinato
  • You can do the same thing with PySpark using the unittest module. The project's own tests use this module: https://github.com/apache/spark/blob/master/python/pyspark/tests.py – Paul K. Nov 19 '15 at 19:12
  • [pytest](https://docs.pytest.org/en/stable/) + [chispa](https://github.com/MrPowers/chispa) make it easy to unit test PySpark code. Avoid unittest. chispa is the native PySpark port of [spark-fast-tests](https://github.com/MrPowers/spark-fast-tests/). See my answer for more details. – Powers Sep 14 '20 at 19:22
  • @PaulK. Hi, the link you shared is invalid :) – wawawa Sep 16 '21 at 14:04

8 Answers


I'd recommend using py.test as well. py.test makes it easy to create reusable SparkContext test fixtures and use them to write concise test functions. You can also specialize fixtures (to create a StreamingContext, for example) and use one or more of them in your tests.

I wrote a blog post on Medium on this topic:

https://engblog.nextdoor.com/unit-testing-apache-spark-with-py-test-3b8970dc013b

Here is a snippet from the post:

import pytest

import wordcount  # module under test, from the blog post

pytestmark = pytest.mark.usefixtures("spark_context")


def test_do_word_counts(spark_context):
    """ test word counting
    Args:
       spark_context: test fixture SparkContext
    """
    test_input = [
        ' hello spark ',
        ' hello again spark spark'
    ]

    input_rdd = spark_context.parallelize(test_input, 1)
    results = wordcount.do_word_counts(input_rdd)

    expected_results = {'hello': 2, 'spark': 3, 'again': 1}
    assert results == expected_results
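The snippet relies on a `wordcount` module that isn't shown here. As a rough sketch (my guess at its shape, not the blog post's code), `do_word_counts` could look like this, assuming it returns a plain dict, since that is what the test compares against:

```python
from operator import add


def do_word_counts(rdd):
    """Hypothetical sketch of wordcount.do_word_counts.

    Splits each line on whitespace, counts every word and returns a plain
    dict so a test can compare it directly against an expected dict.
    """
    counts = (rdd.flatMap(lambda line: line.split())  # one element per word
                 .map(lambda word: (word, 1))         # pair each word with 1
                 .reduceByKey(add)                    # sum the 1s per word
                 .collect())                          # list of (word, count)
    return dict(counts)
```

Because the function does no I/O, the test only has to hand it an RDD built with `parallelize`.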
Vikas Kawadia
  • Welcome to SO! Link-only answers are frowned on. (That is to say, answers which, were the link to disappear, would have no enduring worth.) It is recommended to add a bit of useful text summarizing or highlighting key points from the linked resource. – sclv Mar 18 '16 at 05:08
  • @Vikas Kawadia could you please have a look at `https://stackoverflow.com/questions/49420660/unit-test-pyspark-code-using-python` – User12345 Mar 22 '18 at 05:16
  • The RDD test outlined in the blog post is fine, but the DataFrame test only checks that there are two rows of data. It doesn't verify that the DataFrame schemas and contents are the same, so it's not a robust test. See my answer for a better methodology for making DataFrame comparisons. – Powers Jul 09 '20 at 11:33

Here's a solution with pytest if you're using Spark 2.x and SparkSession. I'm also importing a third-party package (spark-avro).

import logging

import pytest
from pyspark.sql import SparkSession

def quiet_py4j():
    """Suppress spark logging for the test context."""
    logger = logging.getLogger('py4j')
    logger.setLevel(logging.WARN)


@pytest.fixture(scope="session")
def spark_session(request):
    """Fixture for creating a SparkSession."""

    spark = (SparkSession
             .builder
             .master('local[2]')
             .config('spark.jars.packages', 'com.databricks:spark-avro_2.11:3.0.1')
             .appName('pytest-pyspark-local-testing')
             .enableHiveSupport()
             .getOrCreate())
    request.addfinalizer(lambda: spark.stop())

    quiet_py4j()
    return spark


def test_my_app(spark_session):
   ...

Note: with Python 3, I had to specify it in the PYSPARK_PYTHON environment variable:

import os
import sys

IS_PY2 = sys.version_info < (3,)

if not IS_PY2:
    os.environ['PYSPARK_PYTHON'] = 'python3'

Otherwise you get the error:

Exception: Python in worker has different version 2.7 than that in driver 3.5, PySpark cannot run with different minor versions.Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set.
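As Ash Berlin-Taylor points out in the comments on this answer, pointing PYSPARK_PYTHON at `sys.executable` is more robust than hard-coding 'python3', because it follows whichever interpreter (including a virtualenv's) is running the tests. A minimal sketch, assuming it runs at the top of conftest.py, before any SparkSession is created:

```python
import os
import sys

# Must run before the SparkSession/SparkContext is created, because PySpark
# reads PYSPARK_PYTHON when the context starts up.
# sys.executable is the absolute path of the interpreter running pytest, so the
# workers use the same Python version as the driver.
os.environ["PYSPARK_PYTHON"] = sys.executable
```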

Kamil Sindi
  • The avro plugin doesn't work when I use this code on Spark 2.0.2 – clay Mar 17 '17 at 18:59
  • The Avro plugin can be loaded like that with Spark 2.1, but not Spark 2.0.2. You won't get an error until you try to use the Avro format. I've tested this myself. – clay Mar 17 '17 at 21:48
  • A slightly easier way of setting the right value of PYSPARK_PYTHON: `os.environ['PYSPARK_PYTHON'] = sys.executable`; this will set it to whatever the currently running Python is, and will hopefully cope with venvs a bit better too – Ash Berlin-Taylor Feb 01 '18 at 10:44
  • @ksindi could you please have a look at `https://stackoverflow.com/questions/49420660/unit-test-pyspark-code-using-python` – User12345 Mar 22 '18 at 05:20
  • @user9367133 answered your question – Kamil Sindi Mar 22 '18 at 13:46
  • This answer is for old versions of Spark. You don't need to set any environment variables to use PySpark, just use pyspark from PyPi: https://pypi.org/project/pyspark/. Avro is built into Spark as of the 2.4 release. See my answer for a more modern approach. – Powers Jul 09 '20 at 11:41
  • Very nice template for running tests, including the `PYSPARK_*` settings, `enableHiveSupport()` and `quietLog4j`. These seem to still be relevant – WestCoastProjects Dec 06 '22 at 06:44

Assuming you have pyspark installed, you can use the class below to unit test it with unittest:

import unittest
import pyspark


class PySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
        cls.sc = pyspark.SparkContext(conf=conf)
        cls.spark = pyspark.SQLContext(cls.sc)

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()

Example:

class SimpleTestCase(PySparkTestCase):

    def test_with_rdd(self):
        test_input = [
            ' hello spark ',
            ' hello again spark spark'
        ]

        input_rdd = self.sc.parallelize(test_input, 1)

        from operator import add

        results = input_rdd.flatMap(lambda x: x.split()).map(lambda x: (x, 1)).reduceByKey(add).collect()
        # reduceByKey does not guarantee output order, so compare ignoring order
        self.assertCountEqual(results, [('hello', 2), ('spark', 3), ('again', 1)])

    def test_with_df(self):
        df = self.spark.createDataFrame(data=[[1, 'a'], [2, 'b']], 
                                        schema=['c1', 'c2'])
        self.assertEqual(df.count(), 2)

Note that this creates a context per class. Use setUp instead of setUpClass to get a context per test. This typically adds a lot of overhead to test execution, because creating a new SparkContext is expensive.
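`test_with_df` above only checks the row count. A stronger check, sketched here as a hand-rolled helper (the name `assert_dataframes_equal` is made up; chispa, mentioned in another answer, does this more thoroughly), compares the schemas first and then the collected rows, ignoring row order:

```python
def assert_dataframes_equal(actual, expected):
    """Hypothetical helper: fail unless two DataFrames have the same schema and rows.

    Row order is ignored, but duplicate rows still have to match.
    """
    # StructType implements ==, so this compares column names, types and nullability.
    if actual.schema != expected.schema:
        raise AssertionError(
            "Schemas differ:\n{}\n!=\n{}".format(actual.schema, expected.schema))

    # Sort by repr so the comparison ignores row order and doesn't choke on
    # None values, which Python 3 can't order against other types.
    actual_rows = sorted(actual.collect(), key=repr)
    expected_rows = sorted(expected.collect(), key=repr)
    if actual_rows != expected_rows:
        raise AssertionError(
            "Rows differ:\n{}\n!=\n{}".format(actual_rows, expected_rows))
```

For example, inside `test_with_df` you could call `assert_dataframes_equal(df, self.spark.createDataFrame(data=[[2, 'b'], [1, 'a']], schema=['c1', 'c2']))`. Both DataFrames are collected to the driver, so this only suits small test data.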

Jorge Leitao
  • Hola @Jorge. What if I want to create an additional setUpClass in a new test class and I need to access the sparkSession from PySparkTestCase? I've tried calling super().setUpClass() and then accessing super().spark but that doesn't work. – itscarlayall Jun 14 '22 at 13:40
  • Nevermind, already solved it using cls.spark! – itscarlayall Jun 14 '22 at 13:54

I use pytest, which supports test fixtures, so you can instantiate a PySpark context and inject it into all of the tests that require it. Something along the lines of:

import pytest
from pyspark import SparkConf, SparkContext


@pytest.fixture(scope="session",
                params=[pytest.param('local', marks=pytest.mark.spark_local),
                        pytest.param('yarn', marks=pytest.mark.spark_yarn)])
def spark_context(request):
    if request.param == 'local':
        conf = (SparkConf()
                .setMaster("local[2]")
                .setAppName("pytest-pyspark-local-testing")
                )
    elif request.param == 'yarn':
        conf = (SparkConf()
                # "yarn-client" is deprecated since Spark 2.0; "yarn" defaults to client mode
                .setMaster("yarn")
                .setAppName("pytest-pyspark-yarn-testing")
                .set("spark.executor.memory", "1g")
                .set("spark.executor.instances", 2)
                )

    sc = SparkContext(conf=conf)
    request.addfinalizer(lambda: sc.stop())  # stop the context at the end of the session
    return sc


# pytest only collects functions whose names start with "test"
def test_that_requires_sc(spark_context):
    assert spark_context.textFile('/path/to/a/file').count() == 10

Then you can run the tests in local mode by calling py.test -m spark_local or in YARN with py.test -m spark_yarn. This has worked pretty well for me.
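Recent pytest versions warn about marks that have not been registered (and `--strict-markers` turns that into an error). A sketch of registering the two marks used here from a `pytest_configure` hook in conftest.py; the mark descriptions are placeholders:

```python
# conftest.py


def pytest_configure(config):
    """Register the custom marks so `py.test -m spark_local` runs without warnings."""
    config.addinivalue_line(
        "markers", "spark_local: tests that run against a local SparkContext")
    config.addinivalue_line(
        "markers", "spark_yarn: tests that run against a YARN cluster")
```

The same registration can also live under a `markers` entry in pytest.ini.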

santon
  • could you please have a look at `https://stackoverflow.com/questions/49420660/unit-test-pyspark-code-using-python` – User12345 Mar 22 '18 at 05:19

You can test PySpark code by running it on DataFrames in the test suite and then checking that two columns are equal or that two entire DataFrames are equal.

The quinn project has several examples.

Create SparkSession for test suite

Create a tests/conftest.py file with this fixture, so you can easily access the SparkSession in your tests.

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope='session')
def spark():
    return SparkSession.builder \
      .master("local") \
      .appName("chispa") \
      .getOrCreate()

Column equality example

Suppose you'd like to test the following function that removes all non-word characters from a string.

import pyspark.sql.functions as F

def remove_non_word_characters(col):
    return F.regexp_replace(col, "[^\\w\\s]+", "")

You can test this function with the assert_column_equality function that's defined in the chispa library.

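import pyspark.sql.functions as F
from chispa import assert_column_equality
# remove_non_word_characters must also be imported from the module that defines it

def test_remove_non_word_characters(spark):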
    data = [
        ("jo&&se", "jose"),
        ("**li**", "li"),
        ("#::luisa", "luisa"),
        (None, None)
    ]
    df = spark.createDataFrame(data, ["name", "expected_name"])\
        .withColumn("clean_name", remove_non_word_characters(F.col("name")))
    assert_column_equality(df, "clean_name", "expected_name")

DataFrame equality example

Some functions need to be tested by comparing entire DataFrames. Here's a function that sorts the columns in a DataFrame.

def sort_columns(df, sort_order):
    sorted_col_names = None
    if sort_order == "asc":
        sorted_col_names = sorted(df.columns)
    elif sort_order == "desc":
        sorted_col_names = sorted(df.columns, reverse=True)
    else:
        raise ValueError("['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format(
            sort_order=sort_order
        ))
    return df.select(*sorted_col_names)

Here's one test you'd write for this function.

from chispa import assert_df_equality

def test_sort_columns_asc(spark):
    source_data = [
        ("jose", "oak", "switch"),
        ("li", "redwood", "xbox"),
        ("luisa", "maple", "ps4"),
    ]
    source_df = spark.createDataFrame(source_data, ["name", "tree", "gaming_system"])

    actual_df = sort_columns(source_df, "asc")  # the sort_columns function defined above

    expected_data = [
        ("switch", "jose", "oak"),
        ("xbox", "li", "redwood"),
        ("ps4", "luisa", "maple"),
    ]
    expected_df = spark.createDataFrame(expected_data, ["gaming_system", "name", "tree"])

    assert_df_equality(actual_df, expected_df)

Testing I/O

It's generally best to separate code logic from I/O functions, so the logic is easier to test.

Suppose you have a function like this.

def your_big_function():
    df = spark.read.parquet("some_directory")
    df2 = df.withColumn(...).transform(function1).transform(function2)
    df2.write.parquet("other directory")

It's better to refactor the code like this:

def all_logic(df):
    return df.withColumn(...).transform(function1).transform(function2)

def your_formerly_big_function():
    df = spark.read.parquet("some_directory")
    df2 = df.transform(all_logic)
    df2.write.parquet("other directory")

Designing your code like this lets you easily test the all_logic function with the column equality or DataFrame equality functions mentioned above. You can use mocking to test your_formerly_big_function. It's generally best to avoid I/O in test suites (but sometimes unavoidable).
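A minimal sketch of the mocking approach with `unittest.mock`, assuming the function is changed to take the SparkSession as a parameter (the version above uses a global `spark`) and using a hypothetical stand-in `all_logic`, since the real transformations are elided. The mock records the read, transform and write calls, so the test checks the wiring without touching the file system:

```python
from unittest import mock


def all_logic(df):
    # Hypothetical stand-in for the real transformations (elided in the answer).
    return df


def your_formerly_big_function(spark):
    # Same shape as above, but the SparkSession is passed in so a test can mock it.
    df = spark.read.parquet("some_directory")
    df2 = df.transform(all_logic)
    df2.write.parquet("other directory")


def test_your_formerly_big_function_wiring():
    spark = mock.MagicMock()  # stands in for SparkSession and records every call
    your_formerly_big_function(spark)

    spark.read.parquet.assert_called_once_with("some_directory")
    df = spark.read.parquet.return_value
    df.transform.assert_called_once_with(all_logic)
    df.transform.return_value.write.parquet.assert_called_once_with("other directory")
```

`all_logic` itself is then covered separately with the column or DataFrame equality tests.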

Machavity
Powers

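PySpark has unittest-based test base classes that can be used as below:

try:
    # Spark 3.x moved the shared test base classes here
    from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase
except ImportError:
    # older Spark versions
    from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
import pyspark


class MySparkTests(PySparkTestCase):
    def spark_session(self):
        return pyspark.SQLContext(self.sc)

    def createMockDataFrame(self):
        return self.spark_session().createDataFrame(
            [
                ("t1", "t2"),
                ("t1", "t2"),
                ("t1", "t2"),
            ],
            ['col1', 'col2']
        )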
SaiNageswar S

Some time ago I faced the same issue, and after reading through several articles, forums and StackOverflow answers, I ended up writing a small plugin for pytest: pytest-spark

I have been using it for a few months, and the general workflow works well on Linux:

  1. Install Apache Spark (set up a JVM and unpack Spark's distribution to some directory).
  2. Install "pytest" and the "pytest-spark" plugin.
  3. Create "pytest.ini" in your project directory and specify the Spark location there.
  4. Run your tests with pytest as usual.
  5. Optionally, use the "spark_context" fixture provided by the plugin in your tests; it tries to minimize Spark's logs in the output.
Alex Markov

Combining some of the other answers, this is what I found to work on PySpark 3.3 with both fixtures (pytest) and TestCase (unittest). First, set up a session-scoped fixture for the SparkSession in src/tests/conftest.py. It is created once and shared by all the related tests, which avoids the overhead of initializing a new session every time one is needed.

# src/tests/conftest.py

import pytest

from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    spark = (
        SparkSession.builder.master("local[1]")  # run on local machine
        .appName("local-tests")
        .config("spark.executor.cores", "1")
        .config("spark.executor.instances", "1")
        .config("spark.sql.shuffle.partitions", "1")
        .config("spark.driver.bindAddress", "127.0.0.1")
        .getOrCreate()
    )
    yield spark
    spark.stop()

With the function:

# src/utils/spark_utils.py
from pyspark.sql import DataFrame

def my_spark_function(df: DataFrame) -> bool:
   ...

And test:

# src/tests/utils/test_spark_utils.py

from unittest import TestCase

import pytest
from utils.spark_utils import my_spark_function

columns_underscore = ["the", "watchtower"]
data = [("joker", 1), ("thief", 2), ("princes", 3)]


class TestMySparkFunction(TestCase):
    @pytest.fixture(autouse=True)
    def prepare_fixture(self, spark_session):
        self.spark_session = spark_session

    def test_function_okay(self):
        df = self.spark_session.createDataFrame(data=data, schema=columns_underscore)
        self.assertEqual(my_spark_function(df), True)

Finally the tests can be executed with pytest.

Eric Aya
Hugh