
I want to write a unit test (using pytest) for a function that creates a matplotlib plot but returns None.

Say the function show_plot looks like this:

import matplotlib.pyplot as plt

def show_plot():

    # create plot
    plt.plot([1, 2, 3], [4, 5, 3])

    # return None
    return None

When you call show_plot(), you see the created plot, but the plot object is not returned.

How can I write a unit test that checks that show_plot draws the correct plot, or at least that it draws something at all?

EDIT: I can't change or adjust my function show_plot()!

I need something like this:

def test_show_plot():
    # run show_plot
    show_plot()

    # Need help here!
    # ...
    # define plot_created
    # ...

    # logical value of plot_created, which indicates if a plot was
    # indeed created
    assert plot_created

For example, I found an interesting approach for capturing stdout, and I hope there is something similar for capturing plots.

Ferand Dalatieh

2 Answers

1

You want to test that you're using the library in the way you expect.

First you have a dependency on plt. So let's rewrite the function a little.

import matplotlib.pyplot as plt

def show_plot(plt=plt):    
    plt.plot([1, 2, 3], [4, 5, 3])

This allows you to inject a stub so you can test it.

from unittest import mock


def test_show_plot():
    mock_plt = mock.MagicMock()
    show_plot(mock_plt)
    mock_plt.plot.assert_called_once_with([1, 2, 3], [4, 5, 3])

But how do you know this actually creates the plot? Well, try that same call with the real library in an interactive shell and see for yourself that it works.

If you are unable to change the original function, use mock.patch to replace plt in the module where show_plot is defined:

# plot.py
import matplotlib.pyplot as plt

def show_plot():    
    plt.plot([1, 2, 3], [4, 5, 3])


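# test.py
from unittest import mock

from plot import show_plot


# Patch the name `plt` as seen from the module that defines show_plot
# (here plot.py; substitute the real dotted path to your module).
@mock.patch('plot.plt')
def test_show_plot(mock_plt):
    show_plot()
    mock_plt.plot.assert_called_once_with([1, 2, 3], [4, 5, 3])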
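
If everything lives in one file, or you would rather not spell out a dotted module path, a variant of the same idea is to patch only the `plot` attribute of pyplot with `mock.patch.object`. The following is a minimal single-file sketch, not the only way to do it; the Agg backend line and the test name are assumptions added for illustration, and show_plot itself is unchanged.

```python
from unittest import mock

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend for headless test runs
import matplotlib.pyplot as plt


def show_plot():
    # Function under test, kept exactly as in the question
    plt.plot([1, 2, 3], [4, 5, 3])
    return None


def test_show_plot_calls_plot():
    # Replace plt.plot with a mock only for the duration of the with-block
    with mock.patch.object(plt, "plot") as mock_plot:
        result = show_plot()

    # The mock recorded the call, so nothing was actually drawn
    mock_plot.assert_called_once_with([1, 2, 3], [4, 5, 3])
    assert result is None
```

Because `plt.plot` is mocked, this only checks that the function asks matplotlib for the expected plot; it does not check that matplotlib can render it.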
Ferand Dalatieh
munk
  • Thank you! This looks interesting, but I can't change or adjust my function show_plot(). – Ferand Dalatieh Mar 13 '19 at 14:40
  • Thank you for that too! Is there an approach for pytest? – Ferand Dalatieh Mar 13 '19 at 14:49
  • Can you clarify what you mean by "is there an approach for pytest"? `unittest.mock` is part of the standard library and should work with pytest. – munk Mar 13 '19 at 14:51
  • Sorry, I was confused and thought this solution was not compatible with pytest. But now I tried it out and it is working. Unfortunately, I still need to add the parameter plt=plt to show_plot() in order to use this approach. Is there another way without adjusting anything in show_plot()? Thanks! – Ferand Dalatieh Mar 13 '19 at 15:07
  • I found a way to make it work :) I used your solution but with show_plot() instead of show_plot(mock_plt) in the test and kept my function unchanged, and it worked! :) Could you please edit your solution and remove the added argument, because I want to mark your answer as "accepted" for my question. – Ferand Dalatieh Mar 13 '19 at 15:41
  • Oops! I missed that last one. Thanks for the edit @FerandDalatieh! – munk Mar 13 '19 at 15:42
1

The question is really what you want to test.

  • If you want to test that the function works in the sense of "it runs and does not produce any error", just calling the function is enough:

    def test_show_plot():
        show_plot()
    
  • If you want to test that the figure it produces is drawable and that drawing it does not produce any error, draw the canvas explicitly:

    def test_show_plot():
        show_plot()
        plt.gcf().canvas.draw()
    
  • If you want to test that the line has the correct data associated with it, you can get the line from the current axes:

    import matplotlib.pyplot as plt
    from numpy.testing import assert_array_almost_equal
    
    def show_plot():
        plt.plot([1, 2, 3], [4, 5, 3])
        return None
    
    def test_show_plot():
        show_plot()
        line = plt.gca().get_lines()[0]
        assert_array_almost_equal(line.get_data(), [[1, 2, 3], [4, 5, 3]])
    
    test_show_plot()
    
  • Finally, matplotlib has a complete framework for image comparison tests, though it has to be run through pytest and may come with some caveats when used outside of matplotlib itself.
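
The checks above can be combined into one test that also derives the `plot_created` flag the question asks for, by inspecting pyplot's global state after the call. This is a rough sketch under a few assumptions: the Agg backend is chosen so no window opens, all figures are closed before and after the test so other tests cannot interfere, and "a plot was created" is taken to mean exactly one figure with one axes holding one line.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window opens
import matplotlib.pyplot as plt
from numpy.testing import assert_array_almost_equal


def show_plot():
    # Function under test, kept exactly as in the question
    plt.plot([1, 2, 3], [4, 5, 3])
    return None


def test_show_plot():
    plt.close("all")  # start from a clean slate
    try:
        show_plot()

        # pyplot tracks open figures globally, so the plot can be found
        # even though show_plot() returned None
        fig = plt.gcf()
        plot_created = (
            len(plt.get_fignums()) == 1
            and len(fig.axes) == 1
            and len(fig.axes[0].get_lines()) == 1
        )
        assert plot_created

        # Check the data attached to the line
        x, y = fig.axes[0].get_lines()[0].get_data()
        assert_array_almost_equal(x, [1, 2, 3])
        assert_array_almost_equal(y, [4, 5, 3])

        # Check that the figure can actually be rendered
        fig.canvas.draw()
    finally:
        plt.close("all")  # do not leak figures into other tests
```

Under pytest, the two `plt.close("all")` calls could instead live in a fixture shared by all plotting tests.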

ImportanceOfBeingErnest