
I have a DataFrame with 6 columns: id, jan_data, feb_data, mar_data, apr_data, group. The sample data has 2 groups, and I would like to create 2 separate scatter plots, one for each group (the real data can have more groups). The plots should be labeled "group 1", "group 2".

The x-axis should show the column names jan_data, feb_data, mar_data, apr_data, with the axis label "months", and the y-axis should be labeled "value" (written vertically). Each id should be plotted as its own set of points, with that id's value for each month placed at the corresponding month on the x-axis.

import pandas as pd
df_plot = pd.DataFrame({'id': [101,102,103,104,105,106],
                         'jan_data': [30,0,5000,5500,8900,80],
                         'feb_data': [40,0,6000,6780,7800,90],
                         'mar_data': [50,20,7000,4300,6700,78],
                         'apr_data': [60,30,8000,1200,0,67],
                         'group': [2,2,1,1,1,2]})
    id  jan_data    feb_data    mar_data    apr_data    group
0   101 30          40          50          60          2
1   102 0           0           20          30          2
2   103 5000        6000        7000        8000        1
3   104 5500        6780        4300        1200        1
4   105 8900        7800        6700        0           1
5   106 80          90          78          67          2

Can someone please help me with this? Below is a sample image for one group (group 1).

[Image: sample plot for group 1]



First melt your DataFrame so you have a single column for the x-values (the month names) and a single column for the y-values. Then use groupby to create a separate plot for each group. We can use the id to set the color of each point. (Use a different cmap if you expect more than 20 ids in one plot, since tab20 only has 20 distinct colors.) DataFrame.plot.scatter adds a colorbar by default when c is a numeric column, so instead use Matplotlib's scatter (ax.scatter) directly.
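A minimal sketch of just the reshaping step on the question's sample data, to show what melt produces. With id_vars=['id', 'group'], pandas moves the four month column names into a column called variable and their numbers into a column called value (pandas' default names), giving one row per (id, month) pair. The print calls are only there for inspection.

```python
import pandas as pd

# Sample data from the question
df_plot = pd.DataFrame({'id': [101, 102, 103, 104, 105, 106],
                        'jan_data': [30, 0, 5000, 5500, 8900, 80],
                        'feb_data': [40, 0, 6000, 6780, 7800, 90],
                        'mar_data': [50, 20, 7000, 4300, 6700, 78],
                        'apr_data': [60, 30, 8000, 1200, 0, 67],
                        'group': [2, 2, 1, 1, 1, 2]})

# Wide -> long: 6 ids x 4 months = 24 rows; month names land in 'variable',
# the numbers in 'value'
data = df_plot.melt(id_vars=['id', 'group'])

print(data.shape)  # (24, 4)
print(data[data['group'] == 1].head())
```

The month columns are stacked in their original order (all jan_data rows first, then feb_data, and so on), which is why the months also appear in calendar order on a categorical x-axis.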

import matplotlib.pyplot as plt

# Long format: columns id, group, variable (month name), value
data = df_plot.melt(id_vars=['id', 'group'])

# One figure per group
for idx, gp in data.groupby('group'):
    fig, ax = plt.subplots(figsize=(4,4))
    # Month names go on a categorical x-axis; each point is colored by its id
    ax.scatter(x=gp['variable'], y=gp['value'],
               c=gp['id'], cmap='tab20')
    
    ax.grid()
    ax.set_ylabel('Value', fontsize=12)
    ax.set_xlabel('Months', fontsize=12)
    ax.set_title(f'Group {idx}', fontsize=14)
    
    plt.show()

[Images: the resulting scatter plots for group 1 and group 2]
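Coloring points by a numeric id through a colormap gives no legend, and because the ids are scaled onto the colormap's range, ids that are numerically close relative to the overall id range can end up sharing a color. If you need to tell the ids apart, one alternative is a separate scatter call per id plus a legend. This is a sketch under stated assumptions: it reuses the question's df_plot, plot_groups is a hypothetical helper name, the lowercase labels follow the question's wording, and the Agg backend line is only there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to show windows
import matplotlib.pyplot as plt
import pandas as pd

# Sample data from the question
df_plot = pd.DataFrame({'id': [101, 102, 103, 104, 105, 106],
                        'jan_data': [30, 0, 5000, 5500, 8900, 80],
                        'feb_data': [40, 0, 6000, 6780, 7800, 90],
                        'mar_data': [50, 20, 7000, 4300, 6700, 78],
                        'apr_data': [60, 30, 8000, 1200, 0, 67],
                        'group': [2, 2, 1, 1, 1, 2]})


def plot_groups(df, month_cols):
    """Hypothetical helper: return {group: Figure}, one scatter plot per group."""
    # Long format; var_name/value_name only rename melt's default columns
    data = df.melt(id_vars=['id', 'group'], value_vars=month_cols,
                   var_name='month', value_name='value')
    figs = {}
    for grp, gp in data.groupby('group'):
        fig, ax = plt.subplots(figsize=(4, 4))
        # One scatter call per id, so each id gets its own legend entry
        for id_, rows in gp.groupby('id'):
            ax.scatter(rows['month'], rows['value'], label=str(id_))
        ax.set_xlabel('months')
        ax.set_ylabel('value')  # y-axis labels are drawn vertically by default
        ax.set_title(f'group {grp}')
        ax.legend(title='id')
        ax.grid(True)
        figs[grp] = fig
    return figs


figs = plot_groups(df_plot, ['jan_data', 'feb_data', 'mar_data', 'apr_data'])
```

Each returned figure can be saved with fig.savefig(...), or displayed with plt.show() once the Agg line is removed. Matplotlib's default color cycle has 10 colors, so with more than 10 ids in one group the colors repeat unless you pass color= explicitly.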

ALollz