I cannot seem to figure out how to get the linear regression line (aka line of best fit) to span the entire width of the graph. It seems to just go up the the furthest data point on the left and the furthest data point on the right, and no further. How would I fix this?
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.interpolate import *
import MySQLdb
# connect to MySQL database
def mysql_select_all():
conn = MySQLdb.connect(host='localhost',
user='root',
passwd='XXXXX',
db='world')
cursor = conn.cursor()
sql = """
SELECT
GNP, Population
FROM
country
WHERE
Name LIKE 'United States'
OR Name LIKE 'Canada'
OR Name LIKE 'United Kingdom'
OR Name LIKE 'Russia'
OR Name LIKE 'Germany'
OR Name LIKE 'Poland'
OR Name LIKE 'Italy'
OR Name LIKE 'China'
OR Name LIKE 'India'
OR Name LIKE 'Japan'
OR Name LIKE 'Brazil';
"""
cursor.execute(sql)
result = cursor.fetchall()
list_x = []
list_y = []
for row in result:
list_x.append(('%r' % (row[0],)))
for row in result:
list_y.append(('%r' % (row[1],)))
list_x = list(map(float, list_x))
list_y = list(map(float, list_y))
fig = plt.figure()
ax1 = plt.subplot2grid((1,1), (0,0))
p1 = np.polyfit(list_x, list_y, 1) # this line refers to line of regression
ax1.xaxis.labelpad = 50
ax1.yaxis.labelpad = 50
plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression
plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)
plt.xlabel("GNP (US dollars)", fontsize=30)
plt.ylabel("Population(in billions)", fontsize=30)
plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000,
7000000, 8000000, 9000000], rotation=45, fontsize=14)
plt.yticks(fontsize=14)
plt.show()
cursor.close()
mysql_select_all()