This problem is from my training class, where I can only add code inside the function draw_scatterplot(df). I am using Anaconda Spyder, Python 3.8.3, Seaborn 0.10.1, and Matplotlib 3.1.3. How can I return a plot that has both axes and data from draw_scatterplot(df)?

import pandas as pd
import matplotlib 
matplotlib.use('Agg') 
import seaborn as sns 
import pickle 

def draw_scatterplot(df): 
    '''
    Returns a scatter plot.  
    '''
    # Create a scatter plot using Seaborn showing trend of A with B
    # for C.  Set the plot size to 10 inches in width and 2 inches 
    # in height respectively.

    # add your code below
    fig, ax1 = matplotlib.pyplot.subplots(figsize=(10,2))
    ax2 = sns.scatterplot(x='A', y='B', data=df, ax=ax1, hue='C')
    return fig

def serialize_plot(plot, plot_dump_file): 
    with open(plot_dump_file, mode='w+b') as fp: 
        pickle.dump(plot, fp) 

def main(): 
    df = pd.DataFrame(...) 
    plot2 = draw_scatterplot(df) 
    serialize_plot(plot2.axes, "plot2_axes.pk") 
    serialize_plot(plot2.data, "plot2_data.pk") 


> Error: Traceback (most recent call last):
> 
>   File "myscatterplot.py", line 265, in <module>
>     main()
> 
>   File "myscatterplot.py", line 255, in main
>     serialize_plot(plot2.data, "plot2_data.pk")
> 
> AttributeError: 'Figure' object has no attribute 'data'

I also tried returning the axes:

def draw_scatterplot(df): 
    '''
    Returns a scatter plot
    '''
    fig, ax1 = matplotlib.pyplot.subplots(figsize=(10,2))
    ax2 = sns.scatterplot(x='A', y='B', data=df, ax=ax1, hue='C')
    return ax2

Error:
AttributeError: 'AxesSubplot' object has no attribute 'data'

In both cases (returning the figure or the axes), serialize_plot(plot2.axes, "plot2_axes.pk") works, because the returned object has an axes attribute, and I can see that the file "plot2_axes.pk" is created.

vin

2 Answers

To return the entire chart from a function, you can return your fig variable. It contains all the information needed.

import pandas as pd
import matplotlib.pyplot  # explicit import; otherwise matplotlib.pyplot is only available because seaborn imports it
import seaborn as sns
import pickle

def draw_scatterplot(df): 
    '''
    Returns a scatter plot
    '''
    fig, ax1 = matplotlib.pyplot.subplots(figsize=(10,2))
    ax2 = sns.scatterplot(x='A', y='B', data=df, ax=ax1, hue='C')
#     return ax2
    return fig

def serialize_plot(plot, plot_dump_file): 
    with open(plot_dump_file, mode='w+b') as fp: 
        pickle.dump(plot, fp) 

def main(): 
    df = pd.DataFrame({"A":[1,2,3], "B":[6,2,7], "C":[1,0,1]}) 
    plot2 = draw_scatterplot(df) 

main()

(I'm using a Jupyter notebook, hence the call to main() and no plot2.show().)

Output:

[image: output example]

I understand that you eventually want to pickle your figure. For that you can dump plot2 (the fig) directly; there is no need for plot2.data or anything similar.

def main(): 
    df = pd.DataFrame(...) 
    plot2 = draw_scatterplot(df) 
    serialize_plot(plot2, "plot2.pk")
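
As a quick check of this approach, here is a minimal sketch that pickles the whole figure and reads it back. It reuses the sample DataFrame from earlier in this answer. The helper name roundtrip is hypothetical, and pickling to bytes in memory instead of to a file is an assumption made to keep the example self-contained. Reading the points back with get_offsets() assumes the scatter points are the first collection on the axes.

```python
import pickle

import matplotlib
matplotlib.use('Agg')  # non-interactive backend, as in the question
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def draw_scatterplot(df):
    '''Returns the figure holding a 10 x 2 inch seaborn scatter plot.'''
    fig, ax = plt.subplots(figsize=(10, 2))
    sns.scatterplot(x='A', y='B', data=df, ax=ax, hue='C')
    return fig


def roundtrip(obj):
    '''Hypothetical helper: pickle obj to bytes and load it back.'''
    return pickle.loads(pickle.dumps(obj))


df = pd.DataFrame({"A": [1, 2, 3], "B": [6, 2, 7], "C": [1, 0, 1]})
fig = draw_scatterplot(df)
restored = roundtrip(fig)

# The restored figure still carries its axes and the plotted points.
restored_ax = restored.axes[0]
# Assumption: the scatter points are the first collection on the axes.
points = restored_ax.collections[0].get_offsets()
print(restored.get_size_inches())
print(points)
```

So the pickled figure alone is enough to recover both the axes and the plotted coordinates, although it still has no attribute named data.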
Roim
  • Thanks Roim for your answer! I am trying to solve this problem as given in my training class. I can only change code in the method def draw_scatterplot(df). – vin Jul 23 '20 at 18:40
  • you mean that `serialize_plot(plot2.data, "plot2_data.pk")` will be called anyway? If so, please edit your question to clarify that – Roim Jul 23 '20 at 18:51
  • Thanks Roim. I have edited my question above. You are right ```serialize_plot(plot2.data, "plot2_data.pk")``` will be called anyway. – vin Jul 23 '20 at 19:25
  • @vin I'm a bit confused. Please give more details: do you have to use seaborn? Does it have to be a scatter plot? What exactly does the dumped data need to be? If it's part of an exercise, then understanding the broader scope may help. There is no 'data' attribute on the object returned by seaborn's scatterplot – Roim Jul 23 '20 at 21:12
  • I'm confused by this problem too. As per the problem, I need to create a scatter plot using Seaborn. I have also tried `sns.regplot`. I saw a post on extracting the data with `get_offsets()`: https://stackoverflow.com/a/27852570. I am missing something... – vin Jul 23 '20 at 22:30

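I updated the method as shown below and no longer get the error. Attaching the DataFrame to the figure as a data attribute gives plot2.data something to resolve to.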

def draw_scatterplot(df): 
    '''
    Returns a scatter plot
    '''
    fig, ax1 = matplotlib.pyplot.subplots(figsize=(10,2))
    ax2 = sns.scatterplot(x='A', y='B', data=df, ax=ax1, hue='C')
    # return ax2
    fig.data = df
    return fig
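
This workaround relies on Python letting you set an arbitrary attribute on a Figure object. Below is a minimal sketch of how the two calls in main() would then behave. The sample DataFrame is borrowed from the other answer because the original main() elides it, and the pickling is done in memory rather than to "plot2_axes.pk" and "plot2_data.pk" to keep the example self-contained. The question does not say whether the exercise's grader expects plot2.data to be the DataFrame, so treat that part as an assumption.

```python
import pickle

import matplotlib
matplotlib.use('Agg')  # non-interactive backend, as in the question
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def draw_scatterplot(df):
    '''Returns a figure with an extra, ad hoc `data` attribute.'''
    fig, ax = plt.subplots(figsize=(10, 2))
    sns.scatterplot(x='A', y='B', data=df, ax=ax, hue='C')
    fig.data = df  # not a matplotlib feature, just a plain Python attribute
    return fig


# Assumed sample data; the question's main() does not show the real frame.
sample_df = pd.DataFrame({"A": [1, 2, 3], "B": [6, 2, 7], "C": [1, 0, 1]})
plot2 = draw_scatterplot(sample_df)

# Mirror the two serialize_plot calls from main(), using bytes instead of files.
axes_bytes = pickle.dumps(plot2.axes)
data_bytes = pickle.dumps(plot2.data)

restored_axes = pickle.loads(axes_bytes)
restored_data = pickle.loads(data_bytes)
print(len(restored_axes))
print(restored_data.equals(sample_df))
```

Both attribute lookups now succeed: plot2.axes is the figure's list of axes and plot2.data is the DataFrame that was attached inside the function.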
vin