
I am trying to add a vertical line (with plt.axvline) to a scatter plot, at a position based on a column of the DataFrame, but I get the following error: ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().


x_line = datLong.groupby('ctr1').agg({'maxx': ['mean']})

for country in datLong.ctr1.unique():
    temp_df = plt.figure(country)
    temp_df = datLong[datLong.ctr1 == country]
    ax1 = temp_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label = 'xx', linewidth =3, alpha = 0.7, figsize=(7,4))    
   
    plt.title(country)
    plt.axvline(x=x_line) ### this is the line that is causing this error
 
    plt.show()
print (ax1)

The problem seems to be related to the DataFrame, but I can't figure out what it is. Can anybody help me?

Ana
    `plt.axvline` requires that you pass it a number, but `x_line` is a dataframe. – Swier Mar 31 '21 at 09:11
  • Thanks, @Swier. I tried converting it with .numeric, but it didn't work (error: '>' not supported between instances of 'float' and 'function'). In this case I need several values (one per country), and the idea is to add the corresponding line to each country's plot. – Ana Mar 31 '21 at 09:54
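Why a DataFrame triggers this particular message: inside axvline, Matplotlib compares the x value with the current axis limits. The sketch below assumes that comparison has the form (x < xmin) or (x > xmax), as in recent Matplotlib versions, and reproduces the failure with pandas alone; the limits 0 and 20 are arbitrary stand-ins.

```python
import pandas as pd

# A one-row, one-column DataFrame, similar in shape to a slice of x_line.
value = pd.DataFrame({"maxx": [10.0]})

# Comparing a DataFrame with a number yields a boolean DataFrame, not a bool.
# Combining two such results with `or` makes Python call bool() on the first
# one, which pandas refuses to do.
message = None
try:
    outside = (value < 0) or (value > 20)
except ValueError as err:
    message = str(err)
    print(message)

# With a plain float the same expression works, which is why axvline needs a scalar.
scalar = float(value["maxx"].iloc[0])
outside = (scalar < 0) or (scalar > 20)
print(outside)  # False
```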

1 Answer

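x_line contains the values for all the countries. With x_line.loc[country] you get the row for that country. That row is a Series with just one element (agg with a list of functions creates a ('maxx', 'mean') column), and axvline only accepts a single value, so select its first element with x_line.loc[country].iloc[0]. (Plain x_line.loc[country][0] relies on pandas falling back to positional indexing, which recent pandas versions deprecate.)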
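A small pandas-only sketch of the indexing described above, using a hypothetical four-row stand-in for datLong so the means are predictable. It also shows an alternative (not from the original answer): selecting the column before aggregating gives a plain Series indexed by country, so .loc[country] already returns a scalar.

```python
import pandas as pd

# Hypothetical stand-in for datLong with fixed values.
datLong = pd.DataFrame({
    "ctr1": ["country 1", "country 1", "country 2", "country 2"],
    "maxx": [9.0, 11.0, 4.0, 6.0],
})

# agg with a list of functions produces MultiIndex columns: ('maxx', 'mean').
x_line = datLong.groupby("ctr1").agg({"maxx": ["mean"]})

# .loc[country] is a one-element Series; .iloc[0] extracts the scalar by position.
row = x_line.loc["country 1"]
value_from_df = row.iloc[0]

# Alternative: aggregate a single column to get a Series indexed by country.
means = datLong.groupby("ctr1")["maxx"].mean()
value_from_series = means.loc["country 1"]

print(type(row).__name__, value_from_df, value_from_series)  # Series 10.0 10.0
```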

Note that plt.figure() creates a new figure, and pandas' plot without the ax= parameter creates yet another one, so the first figure stays empty. Either leave out plt.figure(), or create an ax explicitly and pass it to plot via ax=.

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

# sample data: 20 random points per country
datLong = pd.DataFrame({'ctr1': np.repeat(['country 1', 'country 2'], 20),
                        'x': np.tile(np.arange(20), 2),
                        'maxx': np.random.randn(40) + 10,
                        'Price': np.random.randn(40) * 10 + 200})

# mean of 'maxx' per country (stored in the MultiIndex column ('maxx', 'mean'))
x_line = datLong.groupby('ctr1').agg({'maxx': ['mean']})

for country in datLong.ctr1.unique():
    temp_df = datLong[datLong.ctr1 == country]
    # without ax=, pandas creates a new figure for each plot
    ax1 = temp_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label='xx', linewidth=3, alpha=0.7,
                       figsize=(7, 4))
    # canvas.set_window_title was removed in newer Matplotlib; use the manager instead
    ax1.figure.canvas.manager.set_window_title(country)
    ax1.set_title(country)
    ax1.axvline(x=x_line.loc[country].iloc[0])  # a single float, as axvline expects
    plt.show()

Since groupby already splits the data per country, you can loop over the groups directly, which removes the need for x_line:

for country, country_df in datLong.groupby('ctr1'):
    ax1 = country_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label='xx', linewidth=3, alpha=0.7,
                          figsize=(7, 4))
    ax1.figure.canvas.manager.set_window_title(country)
    ax1.set_title(country)
    ax1.axvline(x=country_df['maxx'].mean())  # mean of this country's rows only
    plt.show()
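A sketch of the other option from the note about plt.figure: create the figure and axes yourself with plt.subplots and hand the axes to pandas via ax=, so no extra empty figure appears. The seeded random sample data is only a stand-in, and collecting the axes in a dict and calling plt.show() once at the end are choices made here for illustration, not part of the answer above.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Stand-in sample data (seeded so reruns give the same plots).
rng = np.random.default_rng(0)
datLong = pd.DataFrame({'ctr1': np.repeat(['country 1', 'country 2'], 20),
                        'x': np.tile(np.arange(20), 2),
                        'maxx': rng.normal(10, 1, 40),
                        'Price': rng.normal(200, 10, 40)})

axes_by_country = {}
for country, country_df in datLong.groupby('ctr1'):
    # Create the figure and axes explicitly ...
    fig, ax = plt.subplots(figsize=(7, 4))
    # ... and tell pandas to draw on them instead of opening a new figure.
    country_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e',
                    label='xx', linewidth=3, alpha=0.7, ax=ax)
    ax.set_title(country)
    ax.axvline(x=country_df['maxx'].mean())  # a plain float
    axes_by_country[country] = ax

plt.show()  # shows all figures at once
```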
JohanC