import sys, gzip, re, string, random
import numpy as np
from itertools import permutations
from itertools import combinations

import pprint as pp

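# usage
## python3 cal_div.py alufile [idfile] > out
##   alufile : gzip-compressed alignment; each line is "name sequence", and the
##             pieces of a name that appear in several blocks are concatenated
##             (a leading "Kalign" header line is skipped).
##   idfile  : optional; one "id group" pair per line. An alignment sequence whose
##             name contains an id is assigned to that id's group (a group labelled
##             NA is treated as unassigned). Without idfile no groups are formed and
##             nothing is printed.
##   output  : one line per group and per group pair ("AVSB"): observed mean
##             p-distance, mean and variance over 100000 random reassignments of
##             sequences to groups of the same sizes, and how many of those values
##             fall below / within 1e-8 of / above the observed value.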

# Command line: alignment file (required) and id -> group file (optional).
if len(sys.argv) not in (2, 3) :
    sys.exit("usage: python3 cal_div.py alufile [idfile] > out\n")
aln = sys.argv[1]

# sid_group_map : id (matched against alignment names) -> group label.
sid_group_map = {}
if len(sys.argv) == 3 :
    id_file = sys.argv[2]
    with open(id_file, 'rt') as fp :
        for line_no, line in enumerate(fp, 1):
            fields = line.split()
            # Tolerate blank lines (e.g. an extra newline at the end of the file)
            # instead of crashing on the unpack below.
            if not fields :
                continue
            if len(fields) != 2 :
                sys.exit(f"{id_file}:{line_no}: expected 2 columns 'id group', got {len(fields)}\n")
            sid, group_key = fields
            sid_group_map[sid] = group_key

#aln = "test.clu.gz"

#pp.pprint(sid_group_map); exit()

## load in data
# DCT : sequence name -> aligned sequence, concatenated across alignment blocks.
# L   : length of the most recently extended sequence; once the whole file has
#       been read, every sequence is expected to have this length.
DCT = {}
L = 0
with gzip.open(aln, 'rt') as fp :
    for line in fp:
        # Program / format header lines carry no sequence data.
        if line.startswith(("Kalign", "CLUSTAL")) : continue
        # Sequence lines start with the name in column 0. Blank lines and lines
        # starting with whitespace (e.g. conservation-marker lines) are skipped.
        if line[0].isspace() : continue
        # "name sequence [residue-count]"; split() also copes with trailing
        # whitespace, '\r' and a missing final newline.
        fields = line.split()
        sid = fields[0]
        seq = fields[1] if len(fields) > 1 else ''
        DCT[sid] = DCT.get(sid, '') + seq
        L = len(DCT[sid])

#pp.pprint(DCT); exit()

# check the data
# Fail early on an empty alignment instead of running the permutation loop on nothing.
if not DCT :
    sys.exit("No sequences found in " + aln + "\n")
# Every sequence must span the full alignment length L recorded while loading.
for sid in DCT :
    if len(DCT[sid]) != L :
        sys.exit("Error length, check the data and the code\n")

# calculation
## pairwise differences

def pairwise_diff(seq1, seq2) :
    """Return the p-distance between two aligned sequences.

    Only columns where both sequences hold one of the uppercase bases A, T, C
    or G are compared; gaps, N, lowercase letters and any other character are
    skipped.

    Returns
    -------
    float or str
        Fraction of compared columns that differ, or the string 'NA' when no
        column could be compared.

    Exits the program with "Wrong alignment" if the sequences differ in length.
    """
    codes = {'A', 'T', 'C', 'G'}
    # sys.exit rather than the site-module exit(), which is absent under
    # `python -S` and in embedded interpreters.
    if len(seq2) != len(seq1) : sys.exit("Wrong alignment\n")

    num_tot = 0  # compared columns that differ
    den_tot = 0  # compared columns
    for x, y in zip(seq1, seq2):
        if x in codes and y in codes:
            den_tot += 1
            if x != y :
                num_tot += 1

    return float(num_tot) / float(den_tot) if den_tot > 0 else 'NA'

def sorted_key(arr) :
    """Build an order-independent key from a collection of sequence ids.

    The ids are sorted and joined with '-', so ('b', 'a') and ('a', 'b') both
    give 'a-b'. Note that ids which themselves contain '-' can collide.
    """
    ordered = list(arr)
    ordered.sort()
    return '-'.join(ordered)

# Distance table: sorted_key((sid1, sid2)) -> pairwise_diff of the two sequences,
# for every unordered pair, computed once and reused by the observed statistic
# and by every permutation replicate.
sids = DCT.keys()
pairwise_div_dct = {
    sorted_key(pair): pairwise_diff(DCT[pair[0]], DCT[pair[1]])
    for pair in combinations(sids, 2)
}

# divide seq into groups
# ori_sid_group : group label -> list of alignment sequence names in that group.
# A sequence gets the group of the id-file entry that occurs in its name. When
# several ids occur, the longest (most specific) one wins, ties going to the
# earliest id-file line. This way an id "seq1" does not capture a sequence
# "seq10" that has its own entry. Sequences matching no id, or whose group label
# is 'NA', are left out of every group.
ori_sid_group = {}

for sid in DCT :
    hits = [sid_key for sid_key in sid_group_map if sid_key in sid]
    if not hits : continue
    group_key = sid_group_map[max(hits, key=len)]
    if group_key == 'NA' : continue
    ori_sid_group.setdefault(group_key, []).append(sid)

## calculate the observation result

def _mean_p_distance(pairs, pairwise_div_dct) :
    """Mean distance over (sid1, sid2) *pairs*, skipping 'NA' entries; nan if none remain."""
    total = 0
    n = 0
    for pair in pairs :
        d = pairwise_div_dct[sorted_key(pair)]
        if d == 'NA' :  # pair shares no comparable column
            continue
        total += d
        n += 1
    return total / n if n else float('nan')


def cal_average_p_distance(sid_group, pairwise_div_dct) :
    """Average pairwise p-distances within and between groups.

    Parameters
    ----------
    sid_group : dict
        Group label -> list of sequence names.
    pairwise_div_dct : dict
        sorted_key((sid1, sid2)) -> distance (float) or 'NA' (as produced by
        pairwise_diff when two sequences share no comparable column).

    Returns
    -------
    dict
        ``ret[g]`` is the mean over all unordered pairs inside group ``g``.
        ``ret[g1 + 'VS' + g2]`` is the mean over all cross pairs, for each pair
        of groups in ``sid_group`` iteration order. 'NA' distances are ignored.
        When no usable pair remains (e.g. a one-sequence group), the value is
        nan instead of raising, so the key set depends only on the group labels.
    """
    ret_dct = {}
    ## within group
    for group_key, sids in sid_group.items() :
        ret_dct[group_key] = _mean_p_distance(combinations(sids, 2), pairwise_div_dct)
    ## between groups (group labels are distinct dict keys, so never equal)
    for group_key1, group_key2 in combinations(list(sid_group), 2) :
        cross = ((sid1, sid2) for sid1 in sid_group[group_key1] for sid2 in sid_group[group_key2])
        ret_dct[group_key1 + 'VS' + group_key2] = _mean_p_distance(cross, pairwise_div_dct)
    return ret_dct
 

# Observed statistics for the real grouping.
ori_ob = cal_average_p_distance(ori_sid_group, pairwise_div_dct)

# Group sizes stay fixed across permutations.
group_keys = list(ori_sid_group.keys())
group_keys_Len = {k: len(ori_sid_group[k]) for k in group_keys}

# Permutation null: shuffle all sequence names, deal them into groups of the
# original sizes (in group_keys order) and recompute the statistics each time.
# NOTE(review): the pool is every sequence in the alignment, including those that
# matched no group in idfile; confirm this is the intended null rather than
# permuting only the grouped sequences.
N_PERMUTATIONS = 100000
permu_out = {}
sids = list(DCT.keys())
for _ in range(N_PERMUTATIONS) :
    permu_sids = random.sample(sids, len(sids))
    new_sid_group = {}
    start = 0
    for k in group_keys :
        stop = start + group_keys_Len[k]
        new_sid_group[k] = permu_sids[start:stop]
        start = stop
    permu_ob = cal_average_p_distance(new_sid_group, pairwise_div_dct)
    for k, val in permu_ob.items() :
        permu_out.setdefault(k, []).append(val)

def cal_stat(arr) :
    """Return ``[mean, sample_variance]`` of the numbers in *arr*.

    The variance uses the n-1 denominator. It is undefined for a single value,
    so nan is returned instead of raising ZeroDivisionError. An empty *arr*
    still raises ZeroDivisionError.
    """
    n = len(arr)
    ave = sum(arr) / n
    if n < 2 :
        return [ave, float('nan')]
    var = sum((a - ave) * (a - ave) for a in arr) / (n - 1)
    return [ave, var]

# One output line per statistic: key, observed value, permutation mean,
# permutation variance (numpy default ddof=0), then how many permutation values
# fall below, within 1e-8 of, and above the observed value.
for k, x in ori_ob.items() :
    y_arr = permu_out[k]
    sim_ave = np.average(y_arr)
    sim_var = np.var(y_arr)
    lower = x - 1e-8
    upper = x + 1e-8
    x_left = x_mid = x_right = 0
    for y in y_arr :
        if y < lower :
            x_left += 1
        elif y > upper :
            x_right += 1
        else :
            x_mid += 1
    print(k, x, sim_ave, sim_var, x_left, x_mid, x_right)
