Have a go at writing your own, as I don't think this exists in matplotlib. This is a start; you could improve it by adding semi-infinite lines, etc.
import matplotlib.pylab as plt
import numpy as np
def extended(ax, x, y, **kwargs):
    """Plot the straight line through (x, y), extended across the axes.

    A first-degree polynomial is least-squares fitted to the given points
    and evaluated over the current x-limits of *ax*. The x- and y-limits
    that were in force on entry are re-applied after plotting, so adding
    the extended line does not change the visible range.

    Args:
        ax: Axes-like object providing ``get_xlim``, ``get_ylim``,
            ``plot``, ``set_xlim`` and ``set_ylim``.
        x: 1-D sequence of x coordinates lying on the line.
        y: 1-D sequence of y coordinates, same length as *x*.
        **kwargs: Forwarded unchanged to ``ax.plot`` (color, lw, label...).

    Returns:
        The same *ax* that was passed in.
    """
    # Remember the limits first: they are restored once the line is drawn.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    x = np.asarray(x)
    if x.size > 1 and x.min() == x.max():
        # All x equal: the line is vertical and a degree-1 fit in x is
        # rank-deficient, so span the current y-range directly instead.
        x_ext = np.full(2, x.min())
        y_ext = np.asarray(ylim)
    else:
        x_ext = np.linspace(xlim[0], xlim[1], 100)
        coeffs = np.polyfit(x, y, deg=1)
        y_ext = np.polyval(coeffs, x_ext)

    ax.plot(x_ext, y_ext, **kwargs)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return ax
# Demo: scatter 100 random points, then draw a short line segment together
# with its extension produced by extended().
ax = plt.subplot(111)

scatter_x = np.linspace(0, 1, 100)
scatter_y = np.random.random(100)
ax.scatter(scatter_x, scatter_y)

# Short segment of the line y = 0.2 * x, sampled between x = 0.2 and x = 0.7.
x_short = np.linspace(0.2, 0.7)
y_short = x_short * 0.2

# Extended line first (red), then the original short segment (green).
extended_style = dict(color="r", lw=2, label="extended")
ax = extended(ax, x_short, y_short, **extended_style)
ax.plot(x_short, y_short, color="g", lw=4, label="short")

ax.legend()
plt.show()

I just realised you have some red dots on your plots; are those important? Anyway, the main point I think your solution so far is missing is to reset the plot limits to those that existed before plotting; otherwise, as you have found, they get extended.