I plotted a simple relplot using Seaborn, which returned a FacetGrid object.
The code I used is the following:
import seaborn as sns
palette = sns.color_palette("rocket_r")
g1 = sns.relplot(
    data=df,
    x="number_of_weeks", y="avg_streams",
    hue="state", col="state",
    kind="line", palette=palette,
    height=5, aspect=1, facet_kws=dict(sharex=False), col_wrap=3,
)
This produces the following plot:
The DataFrame I'm using has the following structure:
     avg_streams   date  year  number_of_weeks state
4       0.104011  31-01  2020                4    it
5       1.211951  07-02  2020                5    it
6       0.559374  14-02  2020                6    it
7       0.304257  21-02  2020                7    it
8       0.199218  28-02  2020                8    it
..           ...    ...   ...              ...   ...
175    -0.938890  26-06  2020               25    br
176    -0.483821  03-07  2020               26    br
177    -0.083704  10-07  2020               27    br
178     0.165312  17-07  2020               28    br
179     0.218601  24-07  2020               29    br
I would like to add the other states' lines to each subplot. My final goal is to plot all the lines in every subplot, but highlight a different state in each one.
So I would like to get something like this for each subplot in my FacetGrid:
And this is the code I wrote for the previous plot:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Highlight 'it' in red and draw every other state in grey
palette = {c: 'red' if c == 'it' else 'grey' for c in df.state.unique()}
fig, ax = plt.subplots(figsize=(15, 7))
plot = sns.lineplot(ax=ax, x="number_of_weeks", y="avg_streams", hue="state", data=df, palette=palette)
# Thicken the first line drawn (here 'it', the first state in df)
lines = ax.get_lines()
lines[0].set_linewidth(5)
plot.set(title='Streams trend')
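The palette above hard-codes 'it' as the highlighted state. As a minimal sketch of how it could be generalized to one palette per state, assuming a hypothetical helper named `highlight_palette` (it is not part of Seaborn) and a short stand-in list in place of `df.state.unique()`:

```python
def highlight_palette(states, highlighted, on="red", off="grey"):
    """Map every state to `off`, except `highlighted`, which maps to `on`."""
    return {s: (on if s == highlighted else off) for s in states}


# Stand-in for df.state.unique() (assumption: the real list has more states)
states = ["it", "br"]

# One palette per state, each highlighting a different state
palettes = {s: highlight_palette(states, s) for s in states}
```

Each entry of `palettes` can then be passed as `palette=` to `sns.lineplot` when drawing the subplot for that state.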
But I can't figure out how to "merge" the 2 plots. How can I achieve my goal?
EDIT: I tried to add the plot "manually" by selecting the individual axes of my FacetGrid. I followed this question: Add lineplot to subplot, and I was able to add a simple line.
This is my attempt, which adds a simple line to the existing plot:
palette = sns.color_palette("rocket_r")
g1 = sns.relplot(
    data=df,
    x="number_of_weeks", y="avg_streams",
    hue="state", col="state",
    kind="line", palette=palette,
    height=5, aspect=1, facet_kws=dict(sharex=False), col_wrap=3,
)
axes = g1.fig.axes
print(axes)
axes[0].plot([20, 30], [40, 50], 'k-')