# -*- coding: utf-8 -*-
"""
Created on Sat Nov 20 10:51:46 2021

@author: Sim
"""

#import pingouin as pg
import pandas as pd
import numpy as np


#pg.cronbach_alpha(data=df)

#---------- Cronbach alpha for DataFrame ------------
def c_alpha(df):
    """Return Cronbach's alpha for the columns of ``df``.

    Columns are treated as items and rows as observations: the row-wise
    total across all columns is compared against the individual column
    variances. Both use the sample variance (``ddof=1``).

    The value is ``k / (k - 1) * (1 - sum(column variances) / variance(row
    totals))`` where ``k`` is the number of columns. A single-column frame
    raises ``ZeroDivisionError`` because ``k - 1`` is zero.
    """
    n_items = df.shape[1]
    # Variance of the per-row total score across all items
    total_score_var = df.sum(axis=1).var(ddof=1)
    # Sum of the individual item variances (pandas defaults to ddof=1)
    item_var_sum = df.var().sum()
    return n_items / (n_items - 1) * (1 - item_var_sum / total_score_var)

#--------- Cronbach alpha remove item ---------------
def rm_alpha(df):
    """Return Cronbach's alpha recomputed with each column left out in turn.

    For every column position, that single column is removed and alpha is
    computed on the remaining ``k`` columns as
    ``k / (k - 1) * (1 - sum(column variances) / variance(row totals))``,
    using sample variances (``ddof=1``).

    Columns are removed by position, not by label, so a frame whose column
    labels repeat still has exactly one column removed per step and the
    ``k / (k - 1)`` factor matches the number of columns actually kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns are treated as items; sums are taken across columns per row.

    Returns
    -------
    list
        One alpha value per column of ``df``, in column order.

    Raises
    ------
    ZeroDivisionError
        If ``df`` has exactly two columns, since only one column remains
        after removal and ``k - 1`` is zero.
    """
    n_items = df.shape[1]
    n_kept = n_items - 1
    alphas = []
    for skipped in range(n_items):
        # Positional selection keeps every column except the one at `skipped`,
        # even when several columns share the same label.
        kept_positions = [pos for pos in range(n_items) if pos != skipped]
        remaining = df.iloc[:, kept_positions]
        total_score_var = np.var(remaining.sum(axis=1), ddof=1)
        item_var_sum = np.sum(remaining.var())
        alphas.append(
            n_kept / (n_kept - 1) * (1 - item_var_sum / total_score_var)
        )
    return alphas
  
#----------------- Standardized alpha -----------------
def z_alpha(df):
    """Return Cronbach's alpha computed on column-standardized data.

    Each column of ``df`` is rescaled with scikit-learn's ``StandardScaler``
    and the result is passed to ``c_alpha``. The scaled values are wrapped
    in a new DataFrame without the original column labels.
    """
    from sklearn.preprocessing import StandardScaler

    # Standardize each column (mean 0 / variance 1) before applying c_alpha
    standardized = pd.DataFrame(StandardScaler().fit_transform(df))
    return c_alpha(standardized)
    
#------------- Standardized alpha, computed directly from correlations -----------
def z_alpha2(df):
    """Return the standardized alpha computed from the correlation matrix.

    With ``k`` columns and ``r`` the mean of the off-diagonal correlations
    between distinct column pairs, the result is
    ``k * r / (1 + (k - 1) * r)``.
    """
    corr = df.corr()
    n_items = corr.shape[0]
    # Number of distinct column pairs
    n_pairs = n_items * (n_items - 1) / 2
    # Whole-matrix total minus the unit diagonal, halved because the matrix
    # is symmetric: the sum of all pairwise correlations.
    pair_corr_sum = (corr.sum().sum() - n_items) / 2
    mean_corr = pair_corr_sum / n_pairs
    return (n_items * mean_corr) / (1 + (n_items - 1) * mean_corr)

#----------------- Standardized alpha with each item removed -----------------
def z_rm_alpha(df):
    """Return the item-removed alphas computed on column-standardized data.

    Each column of ``df`` is rescaled with scikit-learn's ``StandardScaler``
    and the result is passed to ``rm_alpha``. The scaled values are wrapped
    in a new DataFrame without the original column labels.
    """
    from sklearn.preprocessing import StandardScaler

    # Standardize each column (mean 0 / variance 1) before applying rm_alpha
    standardized = pd.DataFrame(StandardScaler().fit_transform(df))
    return rm_alpha(standardized)

############## calls 

if __name__ == '__main__':
    # Small three-item demo data set with ten observations
    df = pd.DataFrame({
        'x1': [3, 2, 2, 1, 2, 2, 3, 3, 2, 3],
        'x2': [3, 1, 1, 1, 3, 3, 2, 3, 3, 3],
        'x3': [1, 2, 1, 1, 2, 3, 3, 3, 2, 3],
    })

    # (label, function) pairs; each is evaluated and printed in this order
    demos = (
        ('Cronbach alpha: ', c_alpha),
        ('Cronbach alpha remove item: ', rm_alpha),
        ('standardized Cronbach alpha: ', z_alpha),
        ('standardized Cronbach alpha-2: ', z_alpha2),
        ('standardized Cronbach alpha remove item: ', z_rm_alpha),
    )
    for label, func in demos:
        print(label, func(df))
