I have:
import pandas as pd

df = pd.DataFrame({"State": ["CA", "NY", "CA", "NY", "CA", "NY", "TX", "TX", "TX"],
                   "Company": ["A", "A", "A", "B", "C", "D", "A", "B", "B"],
                   "Profits": [3, 2, 5, 6, 7, 2, 2, 4, 7]})
  State Company  Profits
0    CA       A        3
1    NY       A        2
2    CA       A        5
3    NY       B        6
4    CA       C        7
5    NY       D        2
6    TX       A        2
7    TX       B        4
8    TX       B        7
I would like to create a scatter plot in which each point corresponds to a state. On the x-axis, I want the number of unique companies in that state (e.g. CA has 2 unique companies, A and C). On the y-axis, I want the average of the Profits values for that state (e.g. CA's average is (3 + 5 + 7) / 3 = 5).
I tried:
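import matplotlib.pyplot as plt

# Per-state number of distinct companies (x) and mean profit (y)
n_companies = df.groupby("State")["Company"].nunique()
mean_profits = df.groupby("State")["Profits"].mean()

# Both Series are indexed by State, so the points line up state by state
plt.scatter(n_companies, mean_profits)
plt.show()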
This appears to work, but how do I label each point with its state?
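
For reference, here is a minimal sketch of the kind of result I am after. It assumes that writing the state name next to each point with `ax.annotate` is an acceptable way to label it; the `stats` frame, its column names, and the 5-point text offset are my own choices for illustration, not something required by pandas or matplotlib.

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({
    "State": ["CA", "NY", "CA", "NY", "CA", "NY", "TX", "TX", "TX"],
    "Company": ["A", "A", "A", "B", "C", "D", "A", "B", "B"],
    "Profits": [3, 2, 5, 6, 7, 2, 2, 4, 7],
})

# One row per state: distinct company count and mean profit
# ("stats" and its column names are illustrative choices).
stats = df.groupby("State").agg(
    n_companies=("Company", "nunique"),
    mean_profits=("Profits", "mean"),
)

fig, ax = plt.subplots()
ax.scatter(stats["n_companies"], stats["mean_profits"])

# Write each state's name slightly up and to the right of its point.
for state, row in stats.iterrows():
    ax.annotate(
        state,
        (row["n_companies"], row["mean_profits"]),
        xytext=(5, 5),
        textcoords="offset points",
    )

ax.set_xlabel("Number of unique companies")
ax.set_ylabel("Mean profits")
```

Calling `plt.show()` afterwards should display the figure. Computing both columns in a single `groupby` keeps the x values, y values, and state names in one frame, so the loop can read all three from the same row.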