import pandas as pd
from scipy import stats
import numpy as np
import utils

def gene_diff_cal(gmat_mean, samples, genes, diffQuery, alpha=0.05):
    """Compare two samples gene-by-gene: mean difference and t-test q-values.

    ``gmat_mean`` is indexed positionally.

    - Element 0 is used as the mean matrix.
    - Element 2 is used as the standard deviation.
    - Element 3 is used as the number of observations.

    These are the ``mean``/``std``/``nobs`` arguments of
    ``scipy.stats.ttest_ind_from_stats``. Element 1 is not used. Each
    element must support ``DataFrame.query`` on a ``sample`` field.

    Parameters
    ----------
    gmat_mean : sequence of pandas.DataFrame
        Per-sample statistics tables, see above.
    samples : sequence of str
        Sample names; entries are selected by ``diffQuery``.
    genes : sequence
        Column labels for the returned frames.
    diffQuery : sequence of int
        ``(ctrl_idx, case_idx)``, positions into ``samples``.
    alpha : float, optional
        Threshold applied to the q-values to build the rejection mask.

    Returns
    -------
    diffmat : pandas.DataFrame
        ``case - ctrl`` of the mean matrix.
    qmat : pandas.DataFrame
        q-values from ``utils.get_fdr`` applied to the t-test p-values.
    rejmat : pandas.DataFrame
        Boolean mask ``qmat < alpha``.

    All three frames have columns ``genes`` and an index named ``slice``
    whose value is ``"<case>-<ctrl>"``.
    """
    ctrl = str(samples[diffQuery[0]])
    case = str(samples[diffQuery[1]])
    label = f"{case}-{ctrl}"

    mean, std, nobs = gmat_mean[0], gmat_mean[2], gmat_mean[3]

    def _rows(frame, name):
        # `@name` lets pandas bind the value directly, so sample names that
        # contain quotes or backslashes cannot break the query expression.
        return frame.query("sample == @name").values

    def _frame(values):
        # Wrap a (rows x genes) array and tag every row with the comparison label.
        out = pd.DataFrame(values, columns=genes)
        out["slice"] = label
        return out.set_index("slice")

    ctrl_mean = _rows(mean, ctrl)
    case_mean = _rows(mean, case)

    # Rows of the two samples are paired by position, not by label: both
    # samples are assumed to list their rows in the same order.
    diffmat = _frame(case_mean - ctrl_mean)

    _, p = stats.ttest_ind_from_stats(
        ctrl_mean, _rows(std, ctrl), _rows(nobs, ctrl),
        case_mean, _rows(std, case), _rows(nobs, case),
    )
    # Undefined tests (NaN p, e.g. zero variance in both groups) are treated
    # as non-significant. np.where also copes with a scalar result.
    p = np.where(np.isnan(p), 1.0, p)

    # FDR correction is applied jointly over every entry of the p matrix.
    q = utils.get_fdr(p.reshape(-1)).reshape(p.shape)
    qmat = _frame(q)
    rejmat = qmat < alpha

    return diffmat, qmat, rejmat