import pandas as pd
from tqdm import tqdm
import pysam
from scipy.stats import fisher_exact

def make_pcawg_df(snv_file: str):
    """Load PASS autosomal variants and their tumor WGS read counts.

    The file is read as a header-less, tab-separated table with the columns
    chr, pos, id, ref, var, qual, filter, info; text following a '#' is
    ignored. Rows are kept only if the chromosome is chr1..chr22, the filter
    column equals 'PASS', and the info column carries both ``t_ref_count``
    and ``t_alt_count``.

    Args:
        snv_file: Path to the variant file (anything ``pandas.read_csv``
            accepts).

    Returns:
        A DataFrame sorted by (chr, pos) with the columns chr, pos, ref, var,
        confidence (always 0), wgs_ref_reads, wgs_var_reads and wgs_vaf,
        where wgs_vaf = var / (var + ref).
    """
    output_columns = ['chr', 'pos', 'ref', 'var', 'confidence', 'wgs_ref_reads', 'wgs_var_reads', 'wgs_vaf']
    vcf_columns = ['chr', 'pos', 'id', 'ref', 'var', 'qual', 'filter', 'info']
    df = pd.read_csv(snv_file, comment='#', names=vcf_columns, sep='\t', index_col=False)
    df = df.sort_values(by=['chr', 'pos'])

    autosomes = [f'chr{chrom}' for chrom in range(1, 23)]
    # .copy() so the column assignments below write to an independent frame
    # instead of a slice of the frame returned by read_csv.
    df = df[df['chr'].isin(autosomes) & (df['filter'] == 'PASS')].copy()
    df['confidence'] = 0
    print('Before filtering', len(df))

    # WGS reads from tumor sample.
    # Pull each count out by its key name rather than by its position in the
    # ';'-separated INFO string, so the order of INFO entries does not matter.
    info = df['info'].astype(str)
    ref_reads = info.str.extract(r'(?:^|;)t_ref_count=(\d+)', expand=False)
    var_reads = info.str.extract(r'(?:^|;)t_alt_count=(\d+)', expand=False)

    # Drop records that lack either count.
    has_counts = ref_reads.notna() & var_reads.notna()
    df = df[has_counts].copy()
    df['wgs_ref_reads'] = ref_reads[has_counts].astype(int)
    df['wgs_var_reads'] = var_reads[has_counts].astype(int)
    df['wgs_vaf'] = df.wgs_var_reads / (df.wgs_var_reads + df.wgs_ref_reads)
    return df[output_columns]

def add_atac_reads(df, atac_bam_file: str):
    """Add ATAC-seq reference/variant read counts to ``df`` in place.

    For every row, the per-base coverage at the 1-based position ``pos`` on
    contig ``chr`` is queried from the alignment file (using
    ``count_coverage`` with its default settings), and the counts for the
    row's ``ref`` and ``var`` bases are recorded.

    Three columns are added to ``df``:
        atac_ref_reads: reads supporting the ``ref`` base (int).
        atac_var_reads: reads supporting the ``var`` base (int).
        atac_vaf: var / (ref + var), or NaN when both counts are zero.

    Args:
        df: DataFrame with ``chr``, ``pos``, ``ref`` and ``var`` columns.
        atac_bam_file: Path to the alignment file opened with pysam.

    Returns:
        None; ``df`` is modified in place.

    Raises:
        KeyError: If a ``ref`` or ``var`` value is not one of A, C, G, T.
    """
    # count_coverage returns one array per base, in this order.
    pysam_order = ['A', 'C', 'G', 'T']

    ref_counts = []
    var_counts = []
    vafs = []

    # The context manager guarantees the file handle is closed, also on error.
    with pysam.AlignmentFile(atac_bam_file, 'r') as bam:
        sites = zip(df['chr'], df['pos'], df['ref'], df['var'])
        for chrom, pos, ref, var in tqdm(sites, total=len(df)):
            pos = int(pos)
            # pysam coordinates are 0-based half-open; pos is 1-based.
            coverage = bam.count_coverage(
                contig=chrom,
                start=pos - 1,
                stop=pos,
            )
            cov = {base: counts[0] for base, counts in zip(pysam_order, coverage)}
            r, a = cov[ref], cov[var]
            ref_counts.append(r)
            var_counts.append(a)
            vafs.append(a / (r + a) if (r + a) != 0 else float('nan'))

    # Assign each column once with an explicit dtype instead of writing
    # floats/None cell-by-cell into integer-initialised columns.
    df['atac_ref_reads'] = pd.Series(ref_counts, index=df.index, dtype='int64')
    df['atac_var_reads'] = pd.Series(var_counts, index=df.index, dtype='int64')
    df['atac_vaf'] = pd.Series(vafs, index=df.index, dtype='float64')

def filter_snv(patient_df):
    """Select SNVs whose ATAC allele balance differs from the WGS balance.

    Steps:
      1. Keep rows with 0.4 < wgs_vaf < 0.6 and 0 < atac_vaf < 1.
      2. Compute ``pval`` with Fisher's exact test on the 2x2 table
         [[atac_ref_reads, wgs_ref_reads], [atac_var_reads, wgs_var_reads]].
      3. Compute ``odds_ratio`` as
         (atac_ref_reads / (1 - wgs_vaf)) / (atac_var_reads / wgs_vaf).
      4. Keep rows with (odds_ratio > 1.33 or odds_ratio < 0.67) and
         pval < 0.01.
      5. Add ``total_atac_reads`` and ``total_wgs_reads``.

    Args:
        patient_df: DataFrame with wgs_vaf, atac_vaf, atac_ref_reads,
            atac_var_reads, wgs_ref_reads and wgs_var_reads columns. It is
            not modified.

    Returns:
        A new DataFrame with the surviving rows and the added columns pval,
        odds_ratio, total_atac_reads and total_wgs_reads. The result is an
        empty frame (with those columns) when no row passes.
    """
    balanced_wgs = (patient_df.wgs_vaf > 0.4) & (patient_df.wgs_vaf < 0.6)
    # 0 < atac_vaf < 1 means both alleles have ATAC reads, so the odds ratio
    # below never divides by zero ("no FC issues").
    both_alleles_seen = (patient_df.atac_vaf > 0) & (patient_df.atac_vaf < 1)
    # .copy() so the new columns are written to an independent frame.
    patient_df = patient_df[balanced_wgs & both_alleles_seen].copy()

    # A plain list works for zero rows too, unlike DataFrame.apply(axis=1),
    # which returns a DataFrame on empty input and cannot be assigned to a
    # single column.
    pvals = [
        fisher_exact([[atac_ref, wgs_ref], [atac_var, wgs_var]])[1]
        for atac_ref, wgs_ref, atac_var, wgs_var in zip(
            patient_df['atac_ref_reads'],
            patient_df['wgs_ref_reads'],
            patient_df['atac_var_reads'],
            patient_df['wgs_var_reads'],
        )
    ]
    patient_df['pval'] = pd.Series(pvals, index=patient_df.index, dtype='float64')
    patient_df['odds_ratio'] = (
        (patient_df['atac_ref_reads'] / (1 - patient_df['wgs_vaf']))
        / (patient_df['atac_var_reads'] / patient_df['wgs_vaf'])
    )

    skewed = (patient_df.odds_ratio > 1.33) | (patient_df.odds_ratio < 0.67)
    significant = patient_df.pval < 0.01
    patient_df = patient_df[skewed & significant].copy()
    patient_df['total_atac_reads'] = patient_df.atac_ref_reads + patient_df.atac_var_reads
    patient_df['total_wgs_reads'] = patient_df.wgs_ref_reads + patient_df.wgs_var_reads
    return patient_df

if __name__ == "__main__":
    # For each patient: load the WGS SNVs, annotate them with ATAC read
    # counts, drop rows with missing values, filter, and write a CSV.
    output_columns = [
        'chr', 'pos', 'ref', 'var', 'confidence',
        'wgs_ref_reads', 'wgs_var_reads', 'wgs_vaf',
        'atac_ref_reads', 'atac_var_reads', 'atac_vaf',
    ]
    patients = ['TCGA-A6-A567', 'TCGA-B9-A44B', 'TCGA-HE-A5NH', 'TCGA-QG-A5YV']

    for patient in patients:
        print(patient)

        vcf_path = f'../data/{patient}_pcawg_hg38.vcf.gz'
        bam_path = f'../data/{patient}_atac.bam'

        df = make_pcawg_df(vcf_path)
        add_atac_reads(df, bam_path)  # adds the atac_* columns in place

        # Rows with any missing value (e.g. an undefined atac_vaf) are dropped.
        df = df.dropna()
        mut_df = df[output_columns]
        mut_filtered = filter_snv(mut_df)

        n_muts = len(mut_df)
        n_muts_filtered = len(mut_filtered)
        print(f'{n_muts_filtered}/{n_muts} mutations retained.')

        mut_filtered.to_csv(f'{patient}_SNV.csv', index=False, header=True)
