The key thing to point out here is that the result of `sns.pairplot(...)` is not a matplotlib Axes object; it is a `PairGrid` object. I also noticed that drawing the x=y line after setting the log scale caused problems, whereas drawing it before setting the log scale worked.
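To make the difference concrete, here is a small sketch that uses a hypothetical synthetic DataFrame (not your data) to check what `sns.pairplot` returns and to pull the underlying matplotlib Axes out of the grid through its `axes` attribute, which is a 2D NumPy array of Axes:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical positive-valued data standing in for the real DataFrame
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "x": rng.uniform(1, 100, size=60),
    "y": rng.uniform(1, 100, size=60),
    "group": np.repeat(["a", "b", "c"], 20),
})

g = sns.pairplot(df, x_vars=["x"], y_vars=["y"], hue="group")
print(type(g).__name__)  # prints PairGrid, not a matplotlib Axes

# The individual Axes live in g.axes (a 2D array); this grid is 1 x 1
ax = g.axes[0, 0]
ax.set(xscale="log", yscale="log")
```

Anything you would normally call on an Axes (`set`, `plot`, `axline`, ...) has to be called on an element of `g.axes`, not on `g` itself.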
As a side note, you are probably interested in `sns.scatterplot` or `sns.regplot`, both of which return an Axes object. Those may be more appropriate when you are plotting a single x, y scatterplot rather than a scatterplot matrix, which is what `sns.pairplot` (and the underlying `sns.PairGrid`) is for.
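For the single-scatterplot case, a minimal sketch (again with hypothetical synthetic data) looks like this. `sns.scatterplot` hands back the Axes directly, so you can set the log scale and add the reference line without going through a grid. It assumes matplotlib >= 3.3 for `axline`, and uses the two-point form rather than `slope=`, because as far as I know matplotlib rejects `slope=` on non-linear scales:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical positive-valued data (log scale needs values > 0)
rng = np.random.default_rng(1)
df = pd.DataFrame({
    "x": rng.uniform(1, 100, size=50),
    "y": rng.uniform(1, 100, size=50),
})

ax = sns.scatterplot(data=df, x="x", y="y")  # returns a matplotlib Axes
ax.set(xscale="log", yscale="log")

# Both points lie on y = x, so the infinite line is y = x on the log-log axes too
ax.axline((1, 1), (10, 10), linestyle="--", color="k", lw=1)
```

Because the line is defined by two points on y = x rather than by the current axis limits, the order of `ax.set(...)` and `ax.axline(...)` should not matter here.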
I do not have your data handy, so this is not a complete copy/paste answer, but below is the gist of what you will want to do if you want to accomplish this with a pairplot, using a similar figure built from the penguins dataset.
```python
import seaborn as sns
import matplotlib.pyplot as plt

penguins = sns.load_dataset("penguins")

# Shift the data so the x=y line looks interesting in this particular example
# and every value is >= 1, which keeps it valid in log space
penguins['bill_length_mm_center'] = penguins["bill_length_mm"] - penguins["bill_length_mm"].mean()
penguins['bill_length_mm_center'] += -penguins["bill_length_mm_center"].min() + 1
penguins['bill_depth_mm_center'] = penguins["bill_depth_mm"] - penguins["bill_depth_mm"].mean()
penguins['bill_depth_mm_center'] += -penguins["bill_depth_mm_center"].min() + 1

g = sns.pairplot(x_vars=['bill_length_mm_center'], y_vars=['bill_depth_mm_center'], hue="species", data=penguins)

def modify_plot(*args, **kwargs):
    """Must take x, y arrays as positional arguments and draw onto the "currently active"
    matplotlib Axes. Also needs to accept kwargs called color and label.
    We are not using any of these args in this example, so just capture them all.
    """
    # The "currently active" matplotlib Axes, unless seaborn passes it to us explicitly
    if "ax" in kwargs:
        # Not sure if this is ever used...
        ax = kwargs['ax']
    else:
        ax = plt.gca()
    # Make sure to call ax.plot before ax.set; for some reason it doesn't work the other way around
    # Once matplotlib.__version__ >= 3.3, you can use axline instead (https://stackoverflow.com/a/73490857/658053),
    # preferably the two-point form, since slope= is not supported on log-scaled axes:
    # ax.axline((1, 1), (10, 10))
    # but the following is similar for earlier versions of matplotlib (https://stackoverflow.com/a/60950862/658053)
    xpoints = ypoints = ax.get_xlim()
    ax.plot(xpoints, ypoints, linestyle='--', color='k', lw=1, scalex=False, scaley=False)
    ax.set(xscale="log", yscale="log")

g.map(modify_plot);
```
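A variation on the same idea, sketched under the assumption that you only want one reference line per subplot: with `hue` set, `g.map` appears to call the function once per hue level, so `modify_plot` runs several times on the same Axes. Instead, you can loop over `g.axes.flat` after the grid is built. With the two-point form of `axline` (matplotlib >= 3.3) the line does not depend on the current limits, so it can go after the log scale is set. The DataFrame below is hypothetical synthetic data:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical positive-valued data with three groups, similar in shape to the penguins example
rng = np.random.default_rng(2)
df = pd.DataFrame({
    "x": rng.uniform(1, 100, size=60),
    "y": rng.uniform(1, 100, size=60),
    "group": np.repeat(["a", "b", "c"], 20),
})

g = sns.pairplot(df, x_vars=["x"], y_vars=["y"], hue="group")

# Modify each subplot exactly once, independent of how many hue levels there are
for ax in g.axes.flat:
    ax.set(xscale="log", yscale="log")
    # Dashed y = x line defined by two points on it, so it stays correct in log space
    ax.axline((1, 1), (10, 10), linestyle="--", color="k", lw=1)
```

The trade-off is that this bypasses the `g.map` callback protocol entirely, which is simpler when the modification does not need the x, y data seaborn passes in.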
