import pandas as pd
from scipy import stats
import numpy as np
import utils

# gmat_ctrl and gmat_case each hold (mean, z, std, n) at positions 0-3, as generated by gmat_cal.
# samples: samples[0] is the case label, samples[1] is the ctrl label.
def gene_diff_cal_clust(gmat_case, gmat_ctrl, samples, genes, *, alpha=0.05, equal_var=True):
    """Compare per-gene summary statistics between a case and a control group.

    Parameters
    ----------
    gmat_case, gmat_ctrl :
        Indexable containers of summary statistics (as produced by
        ``gmat_cal``). Only positions 0 (mean), 2 (std) and 3 (number of
        observations) are read; each must expose ``.values``. The case and
        ctrl arrays must have matching shapes, with one column per entry of
        ``genes`` (assumed to be in the same order -- TODO confirm against
        gmat_cal).
    samples :
        Pair of labels: ``samples[0]`` is the case, ``samples[1]`` the
        control. They are joined as ``"<case>-<ctrl>"`` into a ``slice``
        column.
    genes :
        Column labels for the output frames.
    alpha :
        q-value threshold below which an entry is flagged as significant.
    equal_var :
        Forwarded to ``scipy.stats.ttest_ind_from_stats``: ``True`` (default,
        original behaviour) is Student's t-test, ``False`` is Welch's.

    Returns
    -------
    diffmat : pandas.DataFrame
        ``case mean - ctrl mean`` per row and gene, plus a ``slice`` column.
    qmat : pandas.DataFrame
        FDR-adjusted q-values (from ``utils.get_fdr``) plus a ``slice`` column.
    rejmat : pandas.DataFrame
        Boolean ``qmat[genes] < alpha``; it has no ``slice`` column.
    """
    label = f"{samples[0]}-{samples[1]}"  # "<case>-<ctrl>"

    case_mean = gmat_case[0].values
    ctrl_mean = gmat_ctrl[0].values

    # Mean difference, oriented as case minus control.
    diffmat = pd.DataFrame(case_mean - ctrl_mean, columns=genes)
    diffmat['slice'] = label

    # Two-sample t-test from summary stats (mean, std, n). Only the p-value is
    # used, so the (ctrl, case) argument order does not affect the result.
    _, p = stats.ttest_ind_from_stats(
        ctrl_mean,
        gmat_ctrl[2].values,
        gmat_ctrl[3].values,
        case_mean,
        gmat_case[2].values,
        gmat_case[3].values,
        equal_var=equal_var,
    )
    # NaN p-values (e.g. a 0/0 t statistic) are treated as non-significant.
    p[np.isnan(p)] = 1

    # FDR correction is applied jointly across every entry (all rows x genes),
    # then restored to the original layout.
    q = utils.get_fdr(p.reshape(-1)).reshape(p.shape)
    qmat = pd.DataFrame(q, columns=genes)
    qmat['slice'] = label

    rejmat = qmat[genes] < alpha

    return diffmat, qmat, rejmat