# Colab setup: mount Google Drive and switch into the project directory so
# that the relative paths used further down ('data/...', 'img/...') resolve.
from google.colab import drive
drive.mount('/content/drive/')
# IPython magic: only valid inside a notebook cell, not in a plain .py script.
%cd /content/drive/Othercomputers/xiaoxin/Mixed-Effect-Model-Numerical-Algorithm
# Star import: the code below uses lmm_em, pd, plt and sns without importing
# them explicitly, so they presumably all come from this module.
# NOTE(review): lmm_em.py presumably lives in the directory entered above,
# which would be why this import has to follow the %cd -- confirm.
from lmm_em import *
import time
# Read the tab-separated data file (its first row is the header) into a
# plain array, then split it column-wise.
data_name = 'XYZ_MoM'
data_path = f'data/{data_name}.txt'
data = pd.read_table(data_path, sep='\t', header=0).values

# Column layout: column 0 -> y, the next 30 columns -> Z, the remainder -> X.
n_z_cols = 30
y = data[:, 0][:, None]  # keep y two-dimensional, shape (n, 1)
Z = data[:, 1:1 + n_z_cols]
X = data[:, 1 + n_z_cols:]
# Run the EM algorithm (capped at 50 iterations) and report the elapsed time.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; unlike time.time() it is not affected by system clock changes.
start_time = time.perf_counter()
(likelihood_list, omega_list, sigma_beta2_list, sigma_e2_list,
 beta_post_mean) = lmm_em(y, X, Z, max_iter=50)
end_time = time.perf_counter()
print('Run time: ', end_time - start_time, 's')
# run EM algorithm with limited data
# MAX_LENGTH = 200
# MAX_X_LENGTH = 100
# likelihood_list, omega_list, sigma_beta2_list, sigma_e2_list, beta_post_mean = lmm_em(
# y[:MAX_LENGTH], X[:MAX_LENGTH, :MAX_X_LENGTH], Z[:MAX_LENGTH, :])
# Summary figure of the EM run: a 2x2 grid of diagnostic plots.
import os

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Top-left: likelihood trace. The first recorded value is omitted, so the
# x axis starts at 1 to keep the remaining points on their own index.
ax_likelihood = axes[0, 0]
ax_likelihood.plot(range(1, len(likelihood_list)), likelihood_list[1:])
ax_likelihood.set_xlabel('Iteration')
ax_likelihood.set_title('Likelihood')

# Top-right: traces of the two variance estimates.
ax_variance = axes[0, 1]
ax_variance.plot(sigma_beta2_list, label=r'$\sigma_\beta^2$')
ax_variance.plot(sigma_e2_list, label=r'$\sigma_e^2$')
ax_variance.set_xlabel('Iteration')
ax_variance.set_title('Unknown Variance')
ax_variance.legend()

# Bottom-left: one line per row of omega_list (transposed for plotting).
# NOTE(review): presumably rows are omega components and columns are
# iterations -- confirm against lmm_em.
ax_omega = axes[1, 0]
ax_omega.plot(omega_list.T)
ax_omega.set_xlabel('Iteration')
ax_omega.set_title(r'$\omega$')

# Bottom-right: histogram of the last column of omega_list, with a vertical
# marker at beta_post_mean.
# NOTE(review): axvline and the ':.4e' format both need beta_post_mean to be
# scalar-like -- confirm what lmm_em returns.
ax_effects = axes[1, 1]
sns.histplot(omega_list[:, -1], kde=True, bins=10, ax=ax_effects,
             label=r'$\omega$')
ax_effects.axvline(beta_post_mean,
                   color='r',
                   linestyle='--',
                   label=rf'$\beta={beta_post_mean:.4e}$')
ax_effects.set_title('Effects')
ax_effects.legend()

plt.suptitle('EM Algorithm for Linear Mixed Model')
plt.tight_layout()

# savefig does not create missing directories, so make sure the output
# directory exists before writing the image.
os.makedirs('img', exist_ok=True)
# No separator is inserted between 'lmm_em' and data_name; the file name is
# kept as is so existing references to the image keep working.
plt.savefig('img/lmm_em'+data_name+'.png')
plt.show()