import numpy as np
import math
import allel
import msprime
import tskit
from IPython.display import SVG, display
import sys
import matplotlib.pyplot as plt
import pandas as pd

#A High-Resolution Single Nucleotide Polymorphism Genetic Map of the Mouse Genome
#Sagiv Shifman,Jordana Tzenova Bell,Richard R Copley,Martin S Taylor,Robert W Williams,Richard Mott,Jonathan Flint 
#Published: November 14, 2006
#https://doi.org/10.1371/journal.pbio.0040395

# Number of diploid individuals sampled from each population.  Listed in the
# order in which the samples are requested from msprime (IND, JPN, KOR, KAZ,
# SPR); the statistics functions receive twice these values as haplotype counts.
NSIND = 2
NSJPN = 29
NSKOR = 9
NSKAZ = 5
NSSPR = 7

def f4ratio(mts, NS1, NS2, NS3, NS4, NS5):
    """Estimate an admixture proportion as a ratio of two summed f4 statistics.

    The columns of ``mts.genotype_matrix()`` are split into five consecutive
    groups of ``NS1`` .. ``NS5`` haplotypes which are treated, in that order,
    as IND, JPN, KOR, KAZ and SPR.  The value computed is

        sum f4(KAZ, SPR; JPN, IND) / sum f4(KAZ, SPR; KOR, IND)

    where the per-site f4 terms are the numerators returned by
    ``allel.patterson_d`` and NaN terms are skipped.

    Parameters
    ----------
    mts : object exposing ``genotype_matrix()`` (e.g. a mutated tree sequence)
    NS1, NS2, NS3, NS4, NS5 : int
        Number of haplotypes (not individuals) in each group.

    Returns
    -------
    float
        The ratio, capped at 1, or NaN when either summed f4 is not strictly
        positive.
    """
    haplotypes = allel.HaplotypeArray(mts.genotype_matrix())

    # Build consecutive, non-overlapping column index lists, one per group.
    subpops = []
    start = 0
    for size in (NS1, NS2, NS3, NS4, NS5):
        stop = start + int(size)
        subpops.append(list(range(start, stop)))
        start = stop

    ac_ind, ac_jpn, ac_kor, ac_kaz, ac_spr = (
        haplotypes.count_alleles(subpop=subpop, max_allele=1)
        for subpop in subpops
    )

    # Only the numerators (un-normalised f4) are needed; denominators are dropped.
    f4_jpn, _ = allel.patterson_d(ac_kaz, ac_spr, ac_jpn, ac_ind)
    f4_kor, _ = allel.patterson_d(ac_kaz, ac_spr, ac_kor, ac_ind)

    f4_jpn_sum = np.nansum(f4_jpn)
    f4_kor_sum = np.nansum(f4_kor)

    # Both sums must be strictly positive, so the ratio is always > 0 and only
    # the upper bound needs clamping.
    if f4_jpn_sum > 0 and f4_kor_sum > 0:
        return min(f4_jpn_sum / f4_kor_sum, 1)
    return float('nan')

def Hudson_fst(mts, NS1, NS2, NS3, NS4, NS5):
    """Compute Hudson's Fst between IND and the pooled KOR+KAZ haplotypes.

    The columns of ``mts.genotype_matrix()`` are interpreted as five
    consecutive groups of ``NS1`` .. ``NS5`` haplotypes in the order IND, JPN,
    KOR, KAZ, SPR.  Only IND, KOR and KAZ are used here.

    Parameters
    ----------
    mts : object exposing ``genotype_matrix()`` (e.g. a mutated tree sequence)
    NS1, NS2, NS3, NS4 : int
        Number of haplotypes in the IND, JPN, KOR and KAZ groups.
    NS5 : int
        Number of SPR haplotypes.  Unused; kept so the signature matches
        ``f4ratio``.

    Returns
    -------
    (float, int)
        ``fst`` is the ratio of the NaN-skipping sums of the per-site
        numerators and denominators from ``allel.hudson_fst`` (NaN when the
        denominator sum is 0).  ``nsnp`` is the number of sites in the
        genotype matrix.
    """
    haplotypes = allel.HaplotypeArray(mts.genotype_matrix())

    n_ind, n_jpn, n_kor, n_kaz = int(NS1), int(NS2), int(NS3), int(NS4)
    kor_start = n_ind + n_jpn
    kaz_start = kor_start + n_kor

    # Allele counts are only computed for the groups that enter the statistic.
    ac_ind = haplotypes.count_alleles(
        subpop=list(range(n_ind)), max_allele=1)
    ac_kor = haplotypes.count_alleles(
        subpop=list(range(kor_start, kaz_start)), max_allele=1)
    ac_kaz = haplotypes.count_alleles(
        subpop=list(range(kaz_start, kaz_start + n_kaz)), max_allele=1)

    # Pool KOR and KAZ into a single group by adding their allele counts.
    ac_mus = ac_kor + ac_kaz

    site_num, site_den = allel.hudson_fst(ac_ind, ac_mus)
    nsnp = len(site_num)

    num_sum = np.nansum(site_num)
    den_sum = np.nansum(site_den)

    fst = num_sum / den_sum if den_sum != 0 else float('nan')
    return (fst, nsnp)


def simulation(MLE, seq_length=20000, recombination_rate=6.3e-9,
               mutation_rate=5.7e-9):
    """Simulate one locus under the five-population model and add mutations.

    Parameters
    ----------
    MLE : dict
        Model parameters.  Keys read here: population sizes ``N<POP>`` for
        IND, JPN, KOR, KAZ, SPR, MUS, MUSCAS, CAS, ANC and ANCJPN; migration
        rates ``MIGCAS``, ``MIGMUS`` and ``MIGMC``; admixture time ``TMIX``
        and proportion ``ADMIX``; split times ``TDIV1`` .. ``TDIV4``.
    seq_length : float
        Length of the simulated sequence.
    recombination_rate : float
        Recombination rate passed to ``msprime.sim_ancestry``.
    mutation_rate : float
        Mutation rate passed to ``msprime.sim_mutations``.

    Returns
    -------
    The tree sequence returned by ``msprime.sim_mutations``.  Samples are
    requested for IND, JPN, KOR, KAZ and SPR using the module-level ``NS*``
    sample sizes with ploidy 2.

    Notes
    -----
    ``Demography.set_symmetric_migration_rate`` edits the *initial* migration
    matrix; it is not a timed event.  Calling it again with ``rate=0`` after
    adding a split therefore just overwrites the earlier value and leaves the
    model without any migration.  Changes that are meant to happen at a given
    time are added here with ``add_symmetric_migration_rate_change``.

    NOTE(review): the times at which each migration epoch starts are inferred
    from the order of the original statements (ANCJPN<->KOR from TMIX,
    CAS<->MUS from the moment both populations exist) -- confirm against the
    fastsimcoal2 model definition.
    """
    demography = msprime.Demography()
    for pop_name in ("IND", "JPN", "KOR", "KAZ", "SPR",
                     "MUS", "MUSCAS", "CAS", "ANC", "ANCJPN"):
        demography.add_population(
            initial_size=MLE['N' + pop_name], growth_rate=0, name=pop_name)

    # Migration in force from the present (time 0) backwards.
    demography.set_symmetric_migration_rate(
        populations=["ANCJPN", "IND"], rate=MLE['MIGCAS'])
    demography.set_symmetric_migration_rate(
        populations=["KOR", "KAZ"], rate=MLE['MIGMUS'])

    # JPN is formed by admixture between ANCJPN and KOR at TMIX; going further
    # back, ANCJPN and KOR exchange migrants at rate MIGMC.
    demography.add_admixture(
        MLE['TMIX'], derived="JPN", ancestral=["ANCJPN", "KOR"],
        proportions=[1 - MLE['ADMIX'], MLE['ADMIX']])
    demography.add_symmetric_migration_rate_change(
        time=MLE['TMIX'], populations=["ANCJPN", "KOR"], rate=MLE['MIGMC'])

    # IND and ANCJPN merge into CAS; migration involving them stops.
    demography.add_population_split(
        time=MLE['TDIV1'], derived=["IND", "ANCJPN"], ancestral="CAS")
    demography.add_symmetric_migration_rate_change(
        time=MLE['TDIV1'], populations=["ANCJPN", "IND"], rate=0)
    demography.add_symmetric_migration_rate_change(
        time=MLE['TDIV1'], populations=["ANCJPN", "KOR"], rate=0)

    # KAZ and KOR merge into MUS; migration between them stops.
    demography.add_population_split(
        time=MLE['TDIV2'], derived=["KAZ", "KOR"], ancestral="MUS")
    demography.add_symmetric_migration_rate_change(
        time=MLE['TDIV2'], populations=["KOR", "KAZ"], rate=0)

    # CAS<->MUS migration can only start once both ancestral populations
    # exist, i.e. at the older of the two split times.  Events sharing a time
    # stay in insertion order, so this is applied after the splits above.
    demography.add_symmetric_migration_rate_change(
        time=max(MLE['TDIV1'], MLE['TDIV2']),
        populations=["CAS", "MUS"], rate=MLE['MIGMC'])

    # CAS and MUS merge into MUSCAS; migration between them stops.
    demography.add_population_split(
        time=MLE['TDIV3'], derived=["CAS", "MUS"], ancestral="MUSCAS")
    demography.add_symmetric_migration_rate_change(
        time=MLE['TDIV3'], populations=["CAS", "MUS"], rate=0)

    demography.add_population_split(
        time=MLE['TDIV4'], derived=["MUSCAS", "SPR"], ancestral="ANC")

    # Events were added grouped by meaning, not by time.
    demography.sort_events()

    ts = msprime.sim_ancestry(
        samples={"IND": NSIND, "JPN": NSJPN, "KOR": NSKOR,
                 "KAZ": NSKAZ, "SPR": NSSPR},
        demography=demography,
        sequence_length=seq_length,
        recombination_rate=recombination_rate,
        ploidy=2)
    mts = msprime.sim_mutations(ts, rate=mutation_rate)
    return mts

num = 103

## .bestlhoods file holding the highest-likelihood parameter estimates from
## the fastsimcoal2 run: first line = parameter names, second line = values.
input_file = f"../fastsimcoal/shell/5pop_gf/5pop_gf_{num}/5pop_gf_{num}.bestlhoods"

# split() already discards the line terminator, so no character has to be
# sliced off (which would eat a digit if the last line had no newline).
with open(input_file) as bestlhoods:
    keys = bestlhoods.readline().split()
    values = bestlhoods.readline().split()

# Rename columns (by position in the .bestlhoods header) to the key names
# that simulation() looks up.
keys[6] = 'NMUS'
keys[7] = 'NMUSCAS'
keys[10] = 'NANCJPN'
keys[11] = 'NANC'
keys[12] = 'ADMIX'

print(keys, values)
MLE = {key: float(value) for key, value in zip(keys, values)}

# The admixture time is fixed here rather than read from the file.
MLE['TMIX'] = 3000

print(MLE)

alpha_list = []
fst_list = []

for i in range(100000):
    mts = simulation(MLE)
    alpha = f4ratio(mts, NSIND*2, NSJPN*2, NSKOR*2, NSKAZ*2, NSSPR*2)
    fst, nsnp = Hudson_fst(mts, NSIND*2, NSJPN*2, NSKOR*2, NSKAZ*2, NSSPR*2)
    print(i, alpha, fst, nsnp)
    # Keep only replicates with at least 100 sites.
    if nsnp >= 100:
        alpha_list.append(alpha)
        fst_list.append(fst)

df = pd.DataFrame({'alpha': alpha_list, 'fst': fst_list})
df_drop_rows = df.dropna()

### df_drop_rows has two columns, alpha and fst, for the given parameters.
# A DataFrame cannot be passed to json.dump (and json was never imported),
# so the results are written with pandas' own JSON writer.
df_drop_rows.to_json("alpha_simulation.json")