
Is there a function in matplotlib similar to MATLAB's line extensions?

I am basically looking for a way to extend a line segment across the whole plot, out to the edges of the axes. My current plot looks like this.

After looking at another question and applying its formula, I was able to get this far, but it still looks messy.

Does anyone have the magic formula here?

tdy
user2869668
  • Please leave a comment before downvoting. How is this question not useful to the community? – user2869668 Oct 11 '13 at 05:50
  • plt.axline lets you draw infinite lines. See https://matplotlib.org/devdocs/gallery/pyplots/axline.html and https://stackoverflow.com/a/64213625/4542084 – Burrito Dec 24 '21 at 23:06

2 Answers


Have a go at writing your own, as I don't think this exists in matplotlib. Here is a start; you could improve it by adding semi-infinite lines, etc.

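import matplotlib.pyplot as plt
import numpy as np

def extended(ax, x, y, **kwargs):
    """Draw the least-squares line through (x, y) across the whole current view."""
    # Remember the current limits so the new line cannot enlarge them
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Fit y = p[0]*x + p[1] and evaluate it over the full visible x range
    x_ext = np.linspace(xlim[0], xlim[1], 100)
    p = np.polyfit(x, y, deg=1)
    y_ext = np.poly1d(p)(x_ext)
    ax.plot(x_ext, y_ext, **kwargs)

    # Restore the original limits
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return ax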


ax = plt.subplot(111)
ax.scatter(np.linspace(0, 1, 100), np.random.random(100))

x_short = np.linspace(0.2, 0.7)
y_short = 0.2 * x_short

ax = extended(ax, x_short, y_short, color="r", lw=2, label="extended")
ax.plot(x_short, y_short, color="g", lw=4, label="short")

ax.legend()
plt.show()


I just realised you have some red dots on your plots; are those important? Anyway, the main thing I think your solution so far is missing is resetting the plot limits to the values they had before the line was drawn. Otherwise, as you have found, the axes get extended to fit the new line.
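For matplotlib versions before 3.3, the "magic formula" is just the two-point form of a line, y = y1 + m*(x - x1) with m = (y2 - y1)/(x2 - x1), evaluated at the edges of the current axes, followed by restoring the limits as described above. The following is a minimal sketch of that idea under stated assumptions: `extend_segment` and its `ray` flag (one possible take on the semi-infinite option mentioned earlier) are hypothetical names, not matplotlib API, and the axes are assumed not to be inverted.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def extend_segment(ax, p1, p2, ray=False, **kwargs):
    """Hypothetical helper: extend the segment p1 -> p2 to the edges of the view.

    ray=False draws the full line across the axes; ray=True draws only the
    semi-infinite part that starts at p1 and passes through p2.
    Assumes non-inverted axes.
    """
    (x1, y1), (x2, y2) = p1, p2
    # Capture the current view so drawing the extension cannot enlarge it
    xlim, ylim = ax.get_xlim(), ax.get_ylim()

    if x1 == x2:
        # Vertical segment: the slope is undefined, so span the y range instead
        xs = [x1, x1]
        if ray:
            ys = [y1, ylim[1] if y2 > y1 else ylim[0]]
        else:
            ys = list(ylim)
    else:
        # Two-point form: y = y1 + m * (x - x1)
        m = (y2 - y1) / (x2 - x1)
        if ray:
            xs = [x1, xlim[1] if x2 > x1 else xlim[0]]
        else:
            xs = list(xlim)
        ys = [y1 + m * (x - x1) for x in xs]

    (line,) = ax.plot(xs, ys, **kwargs)
    # Restore the original view; whatever lies outside it is clipped
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return line


fig, ax = plt.subplots()
ax.scatter([0, 1, 2, 3], [0, 2, 1, 3])
limits_before = (ax.get_xlim(), ax.get_ylim())
full = extend_segment(ax, (1, 1), (2, 1.5), color="r")
half = extend_segment(ax, (1, 1), (2, 1.5), ray=True, color="g")
```

Computing only the two endpoints at the axes edges (instead of sampling 100 points from a fit) draws the extension exactly through the given segment, and the clipping done by matplotlib takes care of any part that falls outside the y range.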

Greg

New in matplotlib 3.3

There is now an axline method to easily extend arbitrary lines:

Adds an infinitely long straight line. The line can be defined either by two points xy1 and xy2

plt.axline(xy1=(0, 1), xy2=(1, 0.5), color='r')

or defined by one point xy1 and a slope.

plt.axline(xy1=(0, 1), slope=-0.5, color='r')
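As a quick check, the two calls above describe the same line: the slope through (0, 1) and (1, 0.5), substituted into the point-slope form at xy1, gives

$$
m = \frac{y_2 - y_1}{x_2 - x_1} = \frac{0.5 - 1}{1 - 0} = -0.5,
\qquad
y = y_1 + m\,(x - x_1) = 1 - 0.5\,x .
$$

This is also the line y = m*x + b with m, b = -0.5, 1 used in the sample data.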


Sample data for reference:

import numpy as np
import matplotlib.pyplot as plt

x, y = np.random.default_rng(123).random((2, 100)) * 2 - 1
m, b = -0.5, 1
plt.scatter(x, y, c=np.where(y > m*x + b, 'r', 'k'))
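Putting the sample data together with `axline` gives a complete script that extends a short segment to the edges of the axes by passing its two endpoints. This is a sketch assuming matplotlib 3.3 or later; the segment endpoints `p1` and `p2` are made up for illustration and are placed on y = m*x + b.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; use an interactive one plus plt.show() to see the plot
import matplotlib.pyplot as plt
import numpy as np

x, y = np.random.default_rng(123).random((2, 100)) * 2 - 1
m, b = -0.5, 1

fig, ax = plt.subplots()
# Colour points red above the line y = m*x + b, black otherwise
ax.scatter(x, y, c=np.where(y > m * x + b, 'r', 'k'))

# Illustrative short segment on that line, from x = -0.25 to x = 0.25
p1 = (-0.25, m * -0.25 + b)
p2 = (0.25, m * 0.25 + b)
ax.plot([p1[0], p2[0]], [p1[1], p2[1]], color='g', lw=4)

# Extend the segment infinitely in both directions through its endpoints
line = ax.axline(p1, p2, color='r')
```

Because `axline` works out where the line meets the axes at draw time, it keeps spanning the full view after panning or zooming, which the fixed-endpoint approaches in the first answer do not.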
tdy