
I can't figure out how to get the linear regression line (aka line of best fit) to span the entire width of the graph. It only runs from the leftmost data point to the rightmost data point and no further, leaving a gap at each edge of the plot. How would I fix this?

import matplotlib.pyplot as plt
import numpy as np
import MySQLdb

# connect to MySQL database
def mysql_select_all():
    conn = MySQLdb.connect(host='localhost',
                           user='root',
                           passwd='XXXXX',
                           db='world')
    cursor = conn.cursor()
    sql = """
        SELECT
            GNP, Population
        FROM
            country
        WHERE
            Name LIKE 'United States'
                OR Name LIKE 'Canada'
                OR Name LIKE 'United Kingdom'
                OR Name LIKE 'Russia'
                OR Name LIKE 'Germany'
                OR Name LIKE 'Poland'
                OR Name LIKE 'Italy'
                OR Name LIKE 'China'
                OR Name LIKE 'India'
                OR Name LIKE 'Japan'
                OR Name LIKE 'Brazil';
    """

    cursor.execute(sql)
    result = cursor.fetchall()

    # column 0 is GNP (x), column 1 is Population (y)
    list_x = [float(row[0]) for row in result]
    list_y = [float(row[1]) for row in result]

    fig = plt.figure()
    ax1 = plt.subplot2grid((1,1), (0,0))

    p1 = np.polyfit(list_x, list_y, 1)          # this line refers to line of regression

    ax1.xaxis.labelpad = 50
    ax1.yaxis.labelpad = 50

    plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression  
    plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)
    plt.xlabel("GNP (US dollars)", fontsize=30)
    plt.ylabel("Population(in billions)", fontsize=30)
    plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 
                7000000, 8000000, 9000000],  rotation=45, fontsize=14)
    plt.yticks(fontsize=14)

    plt.show()
    cursor.close()
    conn.close()

mysql_select_all()
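For readers without the `world` database, the behaviour can be reproduced with a few made-up points (the values and the helper name below are hypothetical, not the real query results). The fit is evaluated only at the observed x values, so the line stops at min(x) and max(x), while matplotlib's automatic x-margin (`rcParams['axes.xmargin']`, 0.05 of the data range on each side by default) pads the axes beyond them.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_fit_through_data(x, y):
    """Reproduce the question's plot: the fit is evaluated only at the data x values."""
    p = np.polyfit(x, y, 1)
    fig, ax = plt.subplots()
    line, = ax.plot(x, np.polyval(p, x), 'r-')
    ax.scatter(x, y, color='darkgreen', s=100)
    return fig, ax, line

# Hypothetical sample data, for illustration only
x = [1, 2, 4, 7, 9]
y = [2.1, 2.9, 5.2, 8.1, 9.8]

fig, ax, line = plot_fit_through_data(x, y)
x_lo, x_hi = ax.get_xlim()  # autoscaled limits, including the default margin
print("line spans:", min(line.get_xdata()), "to", max(line.get_xdata()))
print("x-axis spans:", x_lo, "to", x_hi)  # wider than the line on both sides
plt.show()
```

The answers below remove that mismatch in one of two ways: extend the line out to the axis limits, or shrink the axis limits to the line.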

Nick T

3 Answers


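MySQLdb is not installed on my system, so I can't run your code as is, but the following lines (replacing the corresponding part of your function) should work.

EDIT based on the comments: you also have to set the x-limits explicitly. Otherwise matplotlib pads the axis beyond the extended line and the gap comes back.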

x_low = 0.9*min(list_x)
x_high = 1.1*max(list_x)
x_extended = np.linspace(x_low, x_high, 100)

p1 = np.polyfit(list_x, list_y, 1)             # this line refers to line of regression

ax1.xaxis.labelpad = 50
ax1.yaxis.labelpad = 50

plt.plot(x_extended, np.polyval(p1,x_extended),'r-')   # this line refers to line of regression
plt.xlim(x_low, x_high)                        # pin the x-limits to the extended range
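The snippet above depends on the question's database. Here is a self-contained sketch of the same idea with made-up sample data; the helper name `plot_with_extended_fit` and its `pad` parameter are illustrative, not part of any library. It assumes all x values are positive (as GNP is), since the limits are computed as 0.9 * min and 1.1 * max.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_with_extended_fit(x, y, pad=0.1):
    """Scatter the data and draw a degree-1 fit that reaches both x-limits.

    Hypothetical helper: widens the x-range by `pad` (10% by default) around
    the data and pins the x-limits to that same range. Assumes x > 0.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    p = np.polyfit(x, y, 1)

    x_low = (1 - pad) * x.min()
    x_high = (1 + pad) * x.max()
    x_line = np.array([x_low, x_high])  # a straight line only needs its two endpoints

    fig, ax = plt.subplots()
    ax.scatter(x, y, color='darkgreen', s=100)
    line, = ax.plot(x_line, np.polyval(p, x_line), 'r-')
    ax.set_xlim(x_low, x_high)  # without this, autoscaling pads past the line again
    return fig, ax, line

# Hypothetical sample data, for illustration only
fig, ax, line = plot_with_extended_fit([1, 2, 4, 7, 9], [2.1, 2.9, 5.2, 8.1, 9.8])
plt.show()
```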

Sheldore

Since you didn't include the data, here is a simple example using some artificial data. The idea here is to find what the value of the regression line would be at the x-limits of your plot, and then force matplotlib not to add the normal 'buffer' at the edges of the data.

import numpy as np
import matplotlib.pyplot as plt

x = [1, 1.8, 3.3, 3.5, 5.5, 6.1]
y = [1, 2.1, 3.0, 3.7, 5.2, 6.4]

p1 = np.polyfit(x, y, 1)

plt.scatter(x, y)
xlims = plt.xlim()                  # x-limits chosen by autoscaling, including the default margin
x_line = np.array(xlims)            # a straight line only needs its two endpoints
plt.plot(x_line, np.polyval(p1, x_line), 'r-', linewidth=1.5)
plt.xlim(xlims)                     # re-pin the limits; otherwise the new line is padded again
plt.show()

Without extending the regression line, the sample data looks like this:

[Figure: before extending the regression line]

And after extending:

[Figure: after extending the regression line]
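If you are on matplotlib 3.3 or newer, `Axes.axline` offers another route: it draws a straight line of infinite extent, so the fit reaches both edges of the axes whatever limits autoscaling picks, and nothing has to be pinned by hand. Below is a minimal sketch with made-up data (the helper name is hypothetical); note that the `slope` argument only makes sense on linear axes.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_with_axline_fit(x, y):
    """Scatter the data and draw the least-squares line across the whole Axes.

    Hypothetical helper; requires matplotlib >= 3.3 for Axes.axline.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    slope, intercept = np.polyfit(x, y, 1)

    fig, ax = plt.subplots()
    ax.scatter(x, y, color='darkgreen', s=100)
    # Anchor the line at a point inside the data range, so that even if the
    # anchor is taken into account by autoscaling it doesn't widen the x-limits.
    x_mid = x.mean()
    ax.axline((x_mid, slope * x_mid + intercept), slope=slope, color='r')
    return fig, ax, slope, intercept

# Hypothetical sample data: points lying on y = 2x + 1
fig, ax, slope, intercept = plot_with_axline_fit([0, 1, 2, 3, 4], [1, 3, 5, 7, 9])
plt.show()
```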

Ghost

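If you don't want the x-axis to extend past your data, set the x-margin to zero. The axis limits then coincide with the smallest and largest x values, which are exactly where your regression line starts and ends, so the line spans the full width: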

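import numpy as np
import matplotlib.pyplot as plt

# list_x and list_y are the lists built from the query in the question
p1 = np.polyfit(list_x, list_y, 1)
fig, ax = plt.subplots()
ax.margins(x=0)  # no automatic padding on the x-axis
# Use the Axes methods (ax.plot, ax.scatter, ...) rather than mixing in plt.plot
ax.plot(list_x, np.polyval(p1,list_x),'r-')
ax.scatter(list_x, list_y, color = 'darkgreen', s = 100)
ax.set_xlabel("GNP (US dollars)", fontsize=30)
ax.set_ylabel("Population(in billions)", fontsize=30)
# set_xticks() expands the x-limits to include every listed tick, so ticks past max(list_x) bring back a gap
ax.set_xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 7000000, 8000000, 9000000])
ax.tick_params(axis='x', labelrotation=45, labelsize=14)  # set_xticks() doesn't accept rotation/fontsize without labels
ax.tick_params(axis='y', labelsize=14)
plt.show()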
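To check the effect of `ax.margins(x=0)` without the database, here is a small sketch with made-up data (the helper name is hypothetical). With the x-margin at zero, the autoscaled x-limits equal min(x) and max(x), which are exactly the endpoints of a fit evaluated at the data points. This sketch leaves tick placement to matplotlib's default locator.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_fit_without_x_margin(x, y):
    """Scatter plus degree-1 fit, with no automatic padding on the x-axis."""
    p = np.polyfit(x, y, 1)
    fig, ax = plt.subplots()
    ax.margins(x=0)  # autoscaled x-limits = exact data range
    line, = ax.plot(x, np.polyval(p, x), 'r-')
    ax.scatter(x, y, color='darkgreen', s=100)
    return fig, ax, line

# Hypothetical sample data, for illustration only
x = [1, 2, 4, 7, 9]
y = [2.1, 2.9, 5.2, 8.1, 9.8]
fig, ax, line = plot_fit_without_x_margin(x, y)
print(ax.get_xlim())  # the x-limits now coincide with min(x) and max(x)
plt.show()
```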

PMende