
So I am currently plotting a scatter graph with many x and y values in matplotlib:

plt.scatter(x, y)

I want to draw a line on this scatter graph that crosses the whole graph (i.e. hits two 'borders'). I know the gradient and the intercept, the m and c in the equation y = mx + c.

I have thought about taking the 4 corner points of the plot (from the min and max scatter x and y values), calculating the start and end coordinates of the line from those, and then plotting it, but that seems very convoluted. Is there a better way to do this, bearing in mind that the line may not even be 'within' the plot?


Example of scatter graph: [image]

As can be seen in the plot, the four corner coordinates are roughly:

  • bottom left: -1,-2
  • top left: -1,2
  • bottom right: 6,-2
  • top right: 6,2

I now have a line to plot that must not exceed these boundaries but, if it enters the plot, must touch two of the borders.

So I could compute y at x = -1 and check whether that value lies between -2 and 2; if it does, the line crosses the left border, so plot it, and so on and so forth for the other borders.
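The border-by-border check described above can be collapsed into a single interval intersection: the line lies inside the box for the x values where both xmin <= x <= xmax and ymin <= m*x + c <= ymax hold, and the visible segment runs between the ends of that overlap. A minimal sketch, assuming an axis-aligned box given as xmin, xmax, ymin, ymax; the function name clip_line_to_box is hypothetical:

```python
def clip_line_to_box(m, c, xmin, xmax, ymin, ymax):
    """Return the endpoints of y = m*x + c inside the box, or None if it misses.

    Hypothetical helper: crops the infinite line to the rectangle
    [xmin, xmax] x [ymin, ymax] by intersecting x-intervals.
    """
    lo, hi = xmin, xmax
    if m == 0:
        # Horizontal line: it is either fully inside the y-range or not at all.
        if not (ymin <= c <= ymax):
            return None
    else:
        # x values where the line meets the bottom and top borders.
        x_at_ymin = (ymin - c) / m
        x_at_ymax = (ymax - c) / m
        # Keep only the x-range where the line is between the two borders.
        lo = max(lo, min(x_at_ymin, x_at_ymax))
        hi = min(hi, max(x_at_ymin, x_at_ymax))
    if lo > hi:
        return None  # the line passes outside the box
    return (lo, m * lo + c), (hi, m * hi + c)


# Box from the question: x in [-1, 6], y in [-2, 2].
print(clip_line_to_box(0.4, -1, -1, 6, -2, 2))  # hits the left and right borders
print(clip_line_to_box(3, -2, -1, 6, -2, 2))    # hits the bottom and top borders
print(clip_line_to_box(1, 10, -1, 6, -2, 2))    # misses the box
```

The two returned points can be passed to ax.plot as ([x0, x1], [y0, y1]). Because the box is fixed numbers, the segment does not follow later zooming or panning.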


Ideally, though, I would create a line from -infinity to infinity and then crop it to fit the plot.
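Matplotlib 3.3 and later provide Axes.axline, which draws exactly such an infinite line and clips it to the current view, including after zooming or panning. A minimal sketch, assuming Matplotlib >= 3.3; the random data below only stands in for the real x and y:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop these two lines and call plt.show() interactively
import matplotlib.pyplot as plt
import numpy as np

m, c = 0.4, -1  # slope and intercept of y = m*x + c

# Placeholder scatter data roughly covering the box from the question.
rng = np.random.default_rng(0)
x = rng.uniform(-0.7, 5.3, 20)
y = rng.uniform(-2, 2, 20)

fig, ax = plt.subplots()
ax.scatter(x, y)
# Infinite line through (0, c) with slope m, clipped to the axes.
line = ax.axline((0, c), slope=m, color="limegreen", linewidth=2)
fig.canvas.draw()  # render once so the line is laid out against the view
```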

maxisme

2 Answers


The idea here is to draw the line y = m*x + y0 into the plot by taking a horizontal line given in axes coordinates, transforming it into data coordinates, applying an Affine2D transform that encodes the line equation, and finally transforming the result to screen (display) coordinates.
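The Affine2D step is the heart of the trick. Affine2D.from_values(a, b, c, d, e, f) builds the matrix [[a, c, e], [b, d, f], [0, 0, 1]], so from_values(1, m, 0, 0, 0, y0) sends any point (x, y) to (x, m*x + y0), discarding the incoming y. A small check with made-up points:

```python
import numpy as np
import matplotlib.transforms as mtransforms

m, y0 = 0.4, -1

# Matrix [[1, 0, 0], [m, 0, y0], [0, 0, 1]]: x' = x, y' = m*x + y0.
aff = mtransforms.Affine2D.from_values(1, m, 0, 0, 0, y0)

# The y component of the input points has no effect on the result.
pts = np.array([[-1.0, 0.0], [6.0, 0.0], [2.0, 123.0]])
mapped = aff.transform(pts)
print(mapped)  # each row is (x, 0.4*x - 1)
```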

The advantage is that you do not need to know the axes limits at all. You may also freely zoom or pan the plot; the line always spans the full width of the axes and is clipped at their boundaries. It hence effectively implements a line running from -infinity to +infinity.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms

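def axaline(m, y0, ax=None, **kwargs):
    """Draw the infinite line y = m*x + y0 into ax (default: current axes)."""
    if ax is None:
        ax = plt.gca()
    # Axes coordinates -> data coordinates, built from the live view limits
    # so the mapping follows every zoom or pan.
    tr = mtransforms.BboxTransformTo(
            mtransforms.TransformedBbox(ax.viewLim, ax.transScale)) + \
         ax.transScale.inverted()
    # Affine map (x, y) -> (x, m*x + y0); the incoming y is ignored.
    aff = mtransforms.Affine2D.from_values(1, m, 0, 0, 0, y0)
    # Data coordinates -> display coordinates.
    trinv = ax.transData
    # Horizontal segment spanning the full axes width (x = 0 to 1).
    line = plt.Line2D([0, 1], [0, 0], transform=tr + aff + trinv, **kwargs)
    ax.add_line(line)
    return line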

x = np.random.rand(20)*6-0.7
y = (np.random.rand(20)-.5)*4
c = (x > 3).astype(int)

fig, ax = plt.subplots()
ax.scatter(x,y, c=c, cmap="bwr")

# draw y=m*x+y0 into the plot
m = 0.4; y0 = -1
axaline(m,y0, ax=ax, color="limegreen", linewidth=5)

plt.show()

[image: resulting plot]

While this solution looks a bit complicated at first sight, you do not need to understand it fully: just copy the axaline function into your code and use it as is.
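The zoom claim can be checked without opening a window: build the line with the same transform chain as the axaline function (copied here so the snippet is self-contained), then compare its endpoints in data coordinates before and after changing the x-limits. A sketch, assuming the non-interactive Agg backend; line_endpoints_in_data is a hypothetical helper added for the check:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the check runs without a display
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np

def axaline(m, y0, ax=None, **kwargs):
    # Same transform chain as the axaline function in the answer.
    if ax is None:
        ax = plt.gca()
    tr = (mtransforms.BboxTransformTo(
              mtransforms.TransformedBbox(ax.viewLim, ax.transScale))
          + ax.transScale.inverted())
    aff = mtransforms.Affine2D.from_values(1, m, 0, 0, 0, y0)
    line = plt.Line2D([0, 1], [0, 0], transform=tr + aff + ax.transData, **kwargs)
    ax.add_line(line)
    return line

def line_endpoints_in_data(line, ax):
    # Hypothetical helper: push the raw vertices through the line's transform
    # to display coordinates, then back into data coordinates.
    display = line.get_transform().transform(line.get_xydata())
    return ax.transData.inverted().transform(display)

m, y0 = 0.4, -1
fig, ax = plt.subplots()
ax.set_xlim(-1, 6)
ax.set_ylim(-2, 2)
line = axaline(m, y0, ax=ax)

ends_before = line_endpoints_in_data(line, ax)
ax.set_xlim(-10, 10)  # "zoom out" without touching the line itself
ends_after = line_endpoints_in_data(line, ax)
print(ends_before)
print(ends_after)
```

After zooming out, the endpoints move to x = -10 and x = 10 (y = -5 and y = 3); the parts beyond the y-limits are hidden by the axes clip path.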


To get the automatic updating without relying on the transform chain to do it, one can instead add callbacks that reset the transform every time the axis limits change.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import transforms

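class axaline():
    def __init__(self, m, y0, ax=None, **kwargs):
        if ax is None:
            ax = plt.gca()
        self.ax = ax
        # Affine map (x, y) -> (x, m*x + y0).
        self.aff = transforms.Affine2D.from_values(1, m, 0, 0, 0, y0)
        self.line = plt.Line2D([0, 1], [0, 0], **kwargs)
        self.update()
        self.ax.add_line(self.line)
        # Rebuild the transform whenever the view limits change.
        self.ax.callbacks.connect('xlim_changed', self.update)
        self.ax.callbacks.connect('ylim_changed', self.update)

    def update(self, evt=None):
        # Axes -> data coordinates for the current limits; recomputed on
        # every limit change. Use self.ax, not a global ax.
        tr = self.ax.transAxes - self.ax.transData
        trinv = self.ax.transData
        self.line.set_transform(tr + self.aff + trinv)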

x = np.random.rand(20)*6-0.7
y = (np.random.rand(20)-.5)*4
c = (x > 3).astype(int)

fig, ax = plt.subplots()
ax.scatter(x,y, c=c, cmap="bwr")

# draw y=m*x+y0 into the plot
m = 0.4; y0 = -1
al = axaline(m,y0, ax=ax, color="limegreen", linewidth=5)

plt.show()
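The callback variant relies on Axes.callbacks: the 'xlim_changed' and 'ylim_changed' signals pass the affected Axes to every connected function (that is what arrives as evt in update). A minimal sketch of the mechanism on its own; the handler name on_xlim_changed is hypothetical:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

calls = []

def on_xlim_changed(axes):
    # Hypothetical handler: Axes callbacks receive the Axes whose limits changed.
    calls.append(axes.get_xlim())

fig, ax = plt.subplots()
cid = ax.callbacks.connect("xlim_changed", on_xlim_changed)
ax.set_xlim(0, 5)     # reported
ax.set_xlim(-2, 2)    # reported
ax.callbacks.disconnect(cid)
ax.set_xlim(1, 3)     # no longer reported
print(calls)
```

Matplotlib's callback registry keeps only weak references to bound methods such as self.update, which is why the example keeps the object alive with al = axaline(...); if the object is garbage-collected, the updates stop.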
ImportanceOfBeingErnest
  • ...quite a sledgehammer...nice – mikuszefski Mar 09 '18 at 13:30
  • Well, one could beautify my quick and dirty version to be a function like your `axaline` including `**kwargs` etc. The pan effect is probably not so important. But where my method definitively can fail is upon plotting a third data set afterwards. In such a case your solution might come in handy. – mikuszefski Mar 09 '18 at 13:43
  • @mikuszefski In theory yes, in practice not, since it is not working as it should currently. There seems to be a problem with the inversion of transforms, which would require to plot this line *after* all other elements (which is of course undesired). Still working on it. – ImportanceOfBeingErnest Mar 09 '18 at 13:46
  • yep.. just tried this one, and unfortunately no, neither pan nor additional data...looking forward to the working solution. – mikuszefski Mar 09 '18 at 13:48
  • ...may I also mention that your slope `m` seems to be `0.5` and not `2`...and the intercept...? – mikuszefski Mar 09 '18 at 13:55
  • I opened an [issue](https://github.com/matplotlib/matplotlib/issues/10741) about it since the problem seems to be somewhere deep in matplotlib. – ImportanceOfBeingErnest Mar 09 '18 at 15:38
  • @mikuszefski I solved the issue with the transforms. I also added another solution which would update itself through callbacks, such that in both solutions new plots can be added and the plot can be zoomed etc. – ImportanceOfBeingErnest Mar 11 '18 at 18:05
  • I like the callback solution a lot, it gives quite some matplotlib insight. Nice! – mikuszefski Mar 14 '18 at 08:33

You may try:

import matplotlib.pyplot as plt
import numpy as np

m=3
c=-2
x1Data= np.random.normal(scale=2, loc=.4, size=25)
y1Data= np.random.normal(scale=3, loc=1.2, size=25)
x2Data= np.random.normal(scale=1, loc=3.4, size=25)
y2Data= np.random.normal(scale=.65, loc=-.2, size=25)

fig = plt.figure()
ax = fig.add_subplot( 1, 1, 1 )
ax.scatter(x1Data, y1Data)
ax.scatter(x2Data, y2Data)
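# Save the limits chosen by autoscaling the scatter data.
ylim = ax.get_ylim()
xlim = ax.get_xlim()
# Evaluate y = m*x + c at the left and right x-limits and join the two points.
ax.plot( xlim, [ m * x + c for x in xlim ], 'r:' )
# Plotting the line would otherwise rescale the axes; restore the saved limits.
ax.set_ylim( ylim )
ax.set_xlim( xlim )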
plt.show()

which gives:

[image: resulting plot]

mikuszefski
  • shouldn't `xlim = ax.get_ylim()` be `xlim = ax.get_xlim()` ? – tmdavison Mar 09 '18 at 12:51
  • also, note that if you set `scalex=False, scaley=False` inside the `ax.plot` command, you don't need to re-set the x and y limits – tmdavison Mar 09 '18 at 12:55
  • @tom concerning the first comment: yes, I changed that...( explains why the image is actually square in data,...ups ) concerning the second comment: `ax.set_autoscale_on(False)` before the plot would work just fine – mikuszefski Mar 09 '18 at 13:22
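tmdavison's suggestion can be sketched as follows: passing scalex=False and scaley=False to Axes.plot keeps that one call from autoscaling, so the limits no longer have to be restored afterwards (they are still read once to place the endpoints). A sketch with made-up normally distributed data:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; use plt.show() interactively instead
import matplotlib.pyplot as plt
import numpy as np

m, c = 3, -2

# Made-up scatter data.
rng = np.random.default_rng(1)
x = rng.normal(loc=1.5, scale=2, size=50)
y = rng.normal(loc=0.5, scale=2, size=50)

fig, ax = plt.subplots()
ax.scatter(x, y)

# Limits chosen by autoscaling the scatter data.
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# Endpoints at the left and right x-limits; scalex/scaley=False means this
# call does not trigger autoscaling, so the view stays as it is.
xs = np.array(xlim)
ax.plot(xs, m * xs + c, "r:", scalex=False, scaley=False)
```

As in the answer, the endpoints are fixed at the limits that exist when the line is drawn, so a later zoom-out or a data set that widens the axes leaves the line short.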