# -*- coding: utf-8 -*-
"""
Created on Mon Dec 27 11:48:28 2021

@author: Sim
"""
import matplotlib.pyplot as plt
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import summary_table

# Load the World Happiness data and order the rows by GDP so that every line
# plot below is drawn left-to-right. DataFrame.sort_values returns a NEW frame
# (it does not sort in place), so the result has to be assigned back.
df = pd.read_csv(r'D:\HTEX\Pythonbk\codesdata\WorldHappinessData2015s.csv')
df = df.sort_values(by='GDP')

y = df['happiness']
x = df['GDP']
X = sm.add_constant(x)  # add the intercept column expected by sm.OLS

# Simple linear regression: happiness ~ const + GDP.
# NOTE: the name `re` is kept from the original script; it would shadow the
# standard-library `re` module if that were ever imported in this file.
re = sm.OLS(y, X).fit()


def _plot_fit_with_intervals(ax, results, x, y, alpha):
    """Draw the data, the fitted line and both interval bands on *ax*.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    results : fitted OLS results object passed to ``summary_table``.
    x, y : the regressor and response values, in the same row order that was
        used to fit *results*.
    alpha : significance level; the intervals drawn have coverage
        ``1 - alpha``.
    """
    _, data, _ = summary_table(results, alpha=alpha)

    # Columns of the summary_table data array used here:
    #   2   -> predicted value
    #   4:6 -> lower/upper bound of the interval for the mean prediction
    #   6:8 -> lower/upper bound of the interval for a new observation
    pred = data[:, 2]
    mean_ci_low, mean_ci_upp = data[:, 4:6].T
    obs_ci_low, obs_ci_upp = data[:, 6:8].T

    ax.plot(x, y, '.')
    ax.plot(x, pred, 'k-', lw=2)
    ax.plot(x, obs_ci_low, 'g-.', lw=1)
    ax.plot(x, obs_ci_upp, 'g-.', lw=1)
    ax.plot(x, mean_ci_low, 'b:', lw=1)
    ax.plot(x, mean_ci_upp, 'b:', lw=1)
    # Label the panel so the two interval levels can be told apart.
    ax.set_title(f'{1 - alpha:.0%} intervals')


#---------------------------------------------------------
# Left panel: 95% intervals; right panel: 99% intervals.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
_plot_fit_with_intervals(ax1, re, x, y, alpha=0.05)
_plot_fit_with_intervals(ax2, re, x, y, alpha=0.01)

plt.show()