import numpy as np
import math
import time

### mutation function
def mutation(allele):
    """Return ``allele`` after one random single-step copy-number change.

    A single uniform draw picks the direction: a draw of 0.5 or more tries
    to add one copy, anything lower tries to remove one.  The step is
    skipped when it would leave the 1..20 range, so the value is returned
    unchanged at the boundaries.
    """
    step_up = np.random.uniform() >= 0.5
    if step_up and allele < 20:
        return allele + 1
    if not step_up and allele > 1:
        return allele - 1
    return allele

def count_key(array, key):
    """Return the sum of ``item[key]`` over every item of ``array``.

    Parameters
    ----------
    array : iterable
        Items that support ``item[key]`` (e.g. a list of dicts).
    key : hashable or int
        Key (or index) looked up in each item.

    Returns
    -------
    The summed values; ``0`` when ``array`` is empty.
    """
    return sum(item[key] for item in array)

def _founder_population(copy_number, origin, n_males, n_females):
    """Return (males, females) founder lists with one independent dict per individual."""
    male = {'Y_PHE': copy_number, 'Y_HAP': origin, 'X_PHE': copy_number, 'X_HAP': origin,
            'X_HI1': origin, 'X_HI2': origin, 'G': origin, 'M': origin}
    female = {'X1_PHE': copy_number, 'X1_HAP': origin, 'X1_HI1': origin, 'X1_HI2': origin,
              'X2_PHE': copy_number, 'X2_HAP': origin, 'X2_HI1': origin, 'X2_HI2': origin,
              'G': origin, 'M': origin}
    # Copy the template for every individual: ``[template] * n`` would make
    # all entries alias one dict, so mutating one founder would mutate all.
    return ([dict(male) for _ in range(n_males)],
            [dict(female) for _ in range(n_females)])


def _population_means(males, females):
    """Return mean (Y_PHE, X_PHE, Y_HAP, X_HAP, G, M) of the current population."""
    def total(individuals, key):
        return sum(ind[key] for ind in individuals)

    n_m, n_f = len(males), len(females)
    n_x = n_m + n_f * 2  # one X per male, two per female
    n_all = n_m + n_f
    mean_Y = total(males, 'Y_PHE') / n_m
    mean_X = (total(males, 'X_PHE') + total(females, 'X1_PHE') + total(females, 'X2_PHE')) / n_x
    mean_Y_hap = total(males, 'Y_HAP') / n_m
    mean_X_hap = (total(males, 'X_HAP') + total(females, 'X1_HAP') + total(females, 'X2_HAP')) / n_x
    mean_G = (total(males, 'G') + total(females, 'G')) / n_all
    mean_M = (total(males, 'M') + total(females, 'M')) / n_all
    return mean_Y, mean_X, mean_Y_hap, mean_X_hap, mean_G, mean_M


def _mutate_locus(individuals, key, rate):
    """Apply ``mutation`` to ``key`` in a Poisson(rate * size) number of distinct individuals."""
    size = len(individuals)
    # Sampling without replacement cannot draw more individuals than exist.
    n_hits = min(np.random.poisson(rate * size), size)
    for idx in np.random.choice(size, size=n_hits, replace=False):
        individuals[idx][key] = mutation(individuals[idx][key])


def _recombine_locus(females, locus, rate):
    """Swap ``X1_<locus>`` and ``X2_<locus>`` in a Poisson(rate * size) number of distinct females."""
    size = len(females)
    n_hits = min(np.random.poisson(rate * size), size)
    key_1, key_2 = 'X1_' + locus, 'X2_' + locus
    for idx in np.random.choice(size, size=n_hits, replace=False):
        female = females[idx]
        # Swapping two equal values is a no-op, so no equality test is needed.
        female[key_1], female[key_2] = female[key_2], female[key_1]


def XYM_simulation_geneflow(parameters):
    """Simulate X/Y/autosome/mitochondria introgression under one-way gene flow.

    Each generation records population means, then applies mutation,
    recombination, and a mating phase that produces ``N`` surviving
    offspring (with migrant parents substituted at rate ``NMIG / N``).

    Parameters
    ----------
    parameters : dict
        Required keys:

        - ``'N'``: number of offspring produced per generation; founders are
          ``int(N/2)`` males and ``int(N/2)`` females.
        - ``'ALPHA'``: autosomal incompatibility (RAA in the manuscript).
        - ``'BETA'``: X-autosome incompatibility (RXA in the manuscript).
        - ``'GAMMA'``: mitochondria-autosome incompatibility (RMA in the manuscript).
        - ``'DELTA'``: fertility reduction by SR distortion, i.e. X-Y
          incompatibility (RXY in the manuscript).
        - ``'NMIG'``: migration rate on the scale of N*m; divided by ``N`` internally.
        - ``'MU_Y'``, ``'MU_X'``: mutation rates of the Y and X copy number.
        - ``'RECOM'``: recombination rate between the SR and X_HI loci.
        - ``'GEN'``: number of generations to run.
        - ``'EPSILON'``: strength of SR distortion (must be non-zero); 2 means
          the male birth rate is twice the female birth rate when Y copy
          number exceeds X copy number, and half when it is lower.
        - ``'MALE'``: share of migration events in which the migrant is the
          father (0.5 means equal male and female migration).

        Optional keys:

        - ``'DIRECTION'`` (default 0): 0 = residents carry copy number 5 and
          migrants 10; 1 = residents carry 10 and migrants 5.
        - ``'VERBOSE'`` (default 100000000): print progress every this many
          generations.

    Returns
    -------
    dict
        Keys ``'mean_Y'``, ``'mean_X'``, ``'mean_Y_hap'``, ``'mean_X_hap'``,
        ``'mean_G'``, ``'mean_M'``; each maps to a list with one value per
        generation, measured before that generation's mutation and mating.

    Raises
    ------
    ValueError
        If ``DIRECTION`` is neither 0 nor 1.
    """
    N = parameters['N']
    ALPHA = parameters['ALPHA']  # autosomal incompatibility, RAA in the manuscript
    BETA = parameters['BETA']  # X-autosome incompatibility, RXA in the manuscript
    GAMMA = parameters['GAMMA']  # mitochondria-autosome incompatibility, RMA in the manuscript
    DELTA = parameters['DELTA']  # fertility reduction by SR distortion (X-Y incompatibility), RXY in the manuscript
    NMIG = parameters['NMIG']  # migration rate, on the scale of N*m
    MU_Y = parameters['MU_Y']  # mutation rate of Y
    MU_X = parameters['MU_X']  # mutation rate of X
    RECOM = parameters['RECOM']  # recombination rate between SR and X_HI loci
    generation = parameters['GEN']  # how many generations to run
    EPSILON = parameters['EPSILON']  # strength of SR distortion
    MALE = parameters['MALE']  # male migration bias, 0.5 means equal migration rate
    DIRECTION = parameters.get('DIRECTION', 0)
    VERBOSE = parameters.get('VERBOSE', 100000000)
    print(VERBOSE)
    if DIRECTION not in (0, 1):
        raise ValueError("DIRECTION must be 0 or 1, got {!r}".format(DIRECTION))

    # The migration rate is supplied as N*m; the simulation uses m.
    M = NMIG / N

    # (copy number, origin flag) carried by residents and by migrants.
    if DIRECTION == 0:
        resident_copy, resident_origin = 5, 0
        migrant_copy, migrant_origin = 10, 1
    else:
        resident_copy, resident_origin = 10, 1
        migrant_copy, migrant_origin = 5, 0

    # Founders: equal numbers of males and females, all with the resident genotype.
    NM = NF = int(N / 2)
    M_list, F_list = _founder_population(resident_copy, resident_origin, NM, NF)

    summary_names = ('mean_Y', 'mean_X', 'mean_Y_hap', 'mean_X_hap', 'mean_G', 'mean_M')
    history = {name: np.zeros(generation) for name in summary_names}

    for g in range(generation):
        NM = len(M_list)
        NF = len(F_list)
        means = _population_means(M_list, F_list)
        for name, value in zip(summary_names, means):
            history[name][g] = value
        _, _, mean_Y_haplotype, mean_X_haplotype, mean_G, mean_M = means

        # Progress output.
        if g % VERBOSE == 0:
            now = time.localtime(time.time())
            print("{}:{}:{}".format(now.tm_hour, now.tm_min, now.tm_sec),
                  g, N, M, NM, NF, mean_Y_haplotype, mean_X_haplotype, mean_G, mean_M)

        # Mutation phase: Y, male X, then each female X copy.
        if MU_Y != 0 or MU_X != 0:
            _mutate_locus(M_list, 'Y_PHE', MU_Y)
            _mutate_locus(M_list, 'X_PHE', MU_X)
            _mutate_locus(F_list, 'X1_PHE', MU_X)
            _mutate_locus(F_list, 'X2_PHE', MU_X)

        # Recombination phase: each X_HI locus is exchanged independently.
        if RECOM != 0:
            _recombine_locus(F_list, 'HI1', RECOM)
            _recombine_locus(F_list, 'HI2', RECOM)

        # Mating phase: repeat until N viable offspring exist.
        new_M_list = []
        new_F_list = []
        while len(new_M_list) + len(new_F_list) < N:
            # Random father and mother (drawn with replacement).
            male_ind = M_list[np.random.randint(NM)]
            female_ind = F_list[np.random.randint(NF)]

            MY, MX = male_ind['Y_PHE'], male_ind['X_PHE']
            MY_haplotype, MX_haplotype = male_ind['Y_HAP'], male_ind['X_HAP']
            MG, FG = male_ind['G'], female_ind['G']
            MX_hi1, MX_hi2 = male_ind['X_HI1'], male_ind['X_HI2']
            FM = female_ind['M']

            # Which maternal X chromosome is transmitted.
            maternal_x = 'X1' if np.random.uniform() > 0.5 else 'X2'
            FX = female_ind[maternal_x + '_PHE']
            FX_haplotype = female_ind[maternal_x + '_HAP']
            FX_hi1 = female_ind[maternal_x + '_HI1']
            FX_hi2 = female_ind[maternal_x + '_HI2']

            # Migration phase: one draw decides whether a parent is replaced
            # by a migrant, and which one.
            buf = np.random.uniform()
            if buf <= M * MALE:
                # Migrant father.
                # NOTE(review): MX and MX_haplotype keep the resident
                # father's values here; only Y, the X_HI loci and G are
                # replaced -- confirm this is intended.
                MY, MY_haplotype = migrant_copy, migrant_origin
                MX_hi1 = MX_hi2 = MG = migrant_origin
            elif buf <= M:
                # Migrant mother.
                FX, FX_haplotype = migrant_copy, migrant_origin
                FX_hi1 = FX_hi2 = FG = FM = migrant_origin

            # Sex-ratio distortion from the father's Y vs X copy number.
            if MY > MX:
                a = EPSILON
            elif MY < MX:
                a = 1 / EPSILON
            else:
                a = 1
            # Expected male frequency, kept within 5%-95%.
            MF_ratio = min(max(a / (1 + a), 0.05), 0.95)

            offspring_G = (MG + FG) / 2
            is_male = np.random.uniform() < MF_ratio
            if is_male:
                # NOTE(review): the son inherits the maternal X (FX_hi1/2
                # below) but this term uses the father's X_HI values --
                # confirm against the manuscript.
                x_hi_mean = (MX_hi1 + MX_hi2) / 2
            else:
                x_hi_mean = (FX_hi1 + MX_hi1 + FX_hi2 + MX_hi2) / 4
            fitness = np.exp(-1 * (ALPHA * np.abs(MG - FG)
                                   + BETA * np.abs(x_hi_mean - offspring_G)
                                   + DELTA * np.abs(MF_ratio - 0.5)
                                   + GAMMA * np.abs(offspring_G - FM)))
            # Inviable offspring are discarded and the mating is retried.
            if np.random.uniform() > fitness:
                continue

            if is_male:
                new_M_list.append({'Y_PHE': MY, 'X_PHE': FX, 'G': offspring_G,
                                   'Y_HAP': MY_haplotype, 'X_HAP': FX_haplotype,
                                   'X_HI1': FX_hi1, 'X_HI2': FX_hi2, 'M': FM})
            else:
                new_F_list.append({'X1_PHE': MX, 'X2_PHE': FX, 'G': offspring_G,
                                   'X1_HAP': MX_haplotype, 'X2_HAP': FX_haplotype,
                                   'X1_HI1': MX_hi1, 'X2_HI1': FX_hi1,
                                   'X1_HI2': MX_hi2, 'X2_HI2': FX_hi2, 'M': FM})
        M_list = new_M_list
        F_list = new_F_list

    return {name: history[name].tolist() for name in summary_names}

