import numpy as np
import argparse

# Command-line interface.  Every option is consumed unconditionally further
# down (split, opened, or handed to a parser), so all of them are marked as
# required: a missing option now produces an argparse usage error instead of
# an AttributeError/TypeError traceback on a None value.
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--isoform_file', type=str, required=True, help='clean (use clean_psl.py) isoform psl file')
parser.add_argument('-r', '--read_files', type=str, required=True, help='comma-separated list of clean (use clean_psl.py) R2C2 read psl files')
parser.add_argument('-s', '--sqanti_file', type=str, required=True, help='SQANTI classification file')
parser.add_argument('-o', '--output_file', type=str, required=True)
parser.add_argument('-g', '--gtf_file', type=str, required=True, help='same gtf file used to run SQANTI (did I mentioned you have to run SQANTI first?)')

args = parser.parse_args()
isoform_file = args.isoform_file
sqanti_file = args.sqanti_file
# One output column is produced per read file, in the order given here.
read_files = args.read_files.split(',')
# Opened (and truncated) as soon as the module is loaded; main() passes this
# handle to write_output().
outfile_handle = open(args.output_file,'w')
gtf_file = args.gtf_file


def read_sqanti_classification(sqanti_file):
    """Map every isoform in a SQANTI classification file to its gene and chromosome.

    Each tab-separated row contributes one entry keyed by its first column;
    the value is ``(gene, chromosome)`` taken from columns 7 and 2 (indices
    6 and 1).  No row is treated specially, so a header row, if the file has
    one, ends up in the result like any other row.  Rows with fewer than
    seven columns (for example blank lines) are skipped.  When an isoform
    name occurs more than once, the last row wins.

    Args:
        sqanti_file: path of the classification file.

    Returns:
        dict mapping isoform name to a ``(gene, chromosome)`` tuple.
    """
    gene_dict = {}
    # The context manager closes the handle even if reading fails part-way.
    with open(sqanti_file) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < 7:
                # Too short to carry a gene column; indexing would fail.
                continue
            gene_dict[fields[0]] = (fields[6], fields[1])
    return gene_dict

def _quoted_attribute(info, key):
    """Return the text between the quotes of ``key "..."`` in *info*, or None if absent."""
    marker = key + ' "'
    if marker not in info:
        return None
    return info.split(marker)[1].split('"')[0]


def read_gtf_file(gtf_file):
    """Collect symbol and coordinates for every ``gene`` record of a GTF file.

    Only rows whose third column is ``gene`` are used.  Rows with fewer than
    nine tab-separated columns (comment lines, blank lines, truncated
    records) are skipped, because the attribute text is read from the ninth
    column.  Gene rows without a ``gene_id`` attribute are skipped as well,
    since there is nothing to key them by; a missing ``gene_name`` is
    reported as ``'-'``.

    Args:
        gtf_file: path of the GTF annotation.

    Returns:
        dict mapping gene_id to ``(gene_name, start, end)``.  ``start`` and
        ``end`` are the unconverted strings from columns 4 and 5.
    """
    gene_dict = {}
    with open(gtf_file) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'gene':
                continue
            info = fields[8]
            gene_id = _quoted_attribute(info, 'gene_id')
            if gene_id is None:
                continue
            gene_symbol = _quoted_attribute(info, 'gene_name')
            if gene_symbol is None:
                # Placeholder for an annotation that carries no gene symbol.
                gene_symbol = '-'
            gene_dict[gene_id] = (gene_symbol, fields[3], fields[4])
    return gene_dict

def parse_clean_psl(psl_file,target_chromosome):
    """Read the alignment block boundaries of one chromosome from a PSL file.

    For every row whose column 14 (index 13) equals *target_chromosome*, the
    block sizes (column 19) and block starts (column 21) are paired up and
    flattened into ``[start1, end1, start2, end2, ...]`` with
    ``end = start + size``.  Rows with fewer than 21 tab-separated columns
    (for example blank lines) are skipped.  When a name in column 10 occurs
    more than once, the last row wins.

    Args:
        psl_file: path of the PSL file.
        target_chromosome: only rows on this target sequence are kept.

    Returns:
        dict mapping the row name (column 10) to its flat list of integer
        block boundaries.
    """
    psl_dict = {}
    with open(psl_file) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < 21 or fields[13] != target_chromosome:
                continue
            name = fields[9]
            # The last comma-separated element is discarded: both columns are
            # expected to end with a trailing comma, which leaves an empty
            # string at the end of the split.
            blocksizes = fields[18].split(',')[:-1]
            blockstarts = fields[20].split(',')[:-1]
            boundaries = []
            for blockstart, blocksize in zip(blockstarts, blocksizes):
                blockstart = int(blockstart)
                boundaries.append(blockstart)
                boundaries.append(blockstart + int(blocksize))
            psl_dict[name] = boundaries
    return psl_dict

def generate_match(isoform_dict):
    """Build the two lookup tables used to assign reads to isoforms.

    Args:
        isoform_dict: dict mapping isoform name to its flat list of block
            boundaries (first element = start, last element = end).

    Returns:
        A ``(pointer, reverse_pointer)`` tuple.

        * ``pointer`` maps ``(start, end)`` coordinate pairs to the list of
          isoform names whose own start and end lie close by: every pair in
          ``range(start-10, start+10)`` x ``range(end-10, end+10)`` is
          registered.
        * ``reverse_pointer`` maps isoform name -> boundary position ->
          ``{coordinate: 1}`` for every coordinate in
          ``range(boundary-5, boundary+5)``.
    """
    end_pairs = {}
    boundary_windows = {}

    for isoform_name, coords in isoform_dict.items():
        windows = boundary_windows[isoform_name] = {}
        first, last = coords[0], coords[-1]

        # Register the isoform under every tolerated (start, end) combination.
        for near_start in range(first - 10, first + 10):
            for near_end in range(last - 10, last + 10):
                end_pairs.setdefault((near_start, near_end), []).append(isoform_name)

        # One small set of accepted coordinates per boundary position.
        for position, coord in enumerate(coords):
            windows[position] = {base: 1 for base in range(coord - 5, coord + 5)}

    return end_pairs, boundary_windows


def sorting_reads(pointer,reverse_pointer,read_file,isoform_dict,chromosome):
    """Count, for one read file, how many reads agree with each isoform.

    Every isoform named in *reverse_pointer* first gets a zero count for
    *read_file*.  The reads of *chromosome* are then parsed from *read_file*;
    a read is looked up in *pointer* by its (first, last) boundary pair, and
    for each candidate isoform every boundary of the read is checked, by
    position, against that isoform's accepted coordinates.  A read for which
    all boundaries pass adds one to the isoform's count.

    Args:
        pointer: ``(start, end)`` -> list of candidate isoform names.
        reverse_pointer: isoform name -> boundary position -> accepted
            coordinates.
        read_file: path of the read PSL file; also the key the counts are
            stored under.
        isoform_dict: isoform name -> {read_file: count}; updated in place.
        chromosome: target sequence whose reads are considered.

    Returns:
        The same *isoform_dict* object, updated.
    """
    # Make sure every known isoform reports a count for this file, even 0.
    for isoform in reverse_pointer:
        if not isoform_dict.get(isoform):
            isoform_dict[isoform] = {}
        isoform_dict[isoform][read_file] = 0

    reads = parse_clean_psl(read_file, chromosome)
    for coords in reads.values():
        candidates = pointer.get((coords[0], coords[-1]))
        if not candidates:
            continue
        for isoform in candidates:
            windows = reverse_pointer[isoform]
            for position, coord in enumerate(coords):
                accepted = windows.get(position)
                if not accepted or not accepted.get(coord):
                    # Boundary missing in the isoform or out of tolerance.
                    break
            else:
                # Loop finished without a break: every boundary agreed.
                isoform_dict[isoform][read_file] += 1

    return isoform_dict

def write_output(out,count_dict,read_files,gene_dict,gene_symbols):
    """Write one tab-separated line of read counts per isoform, grouped by gene.

    Each line has the form::

        <gene>_<symbol>_<chromosome>_<start>_<end>~<isoform>\\t<count>\\t<count>...

    with one count per entry of *read_files*, in that order.  Isoforms are
    grouped by their ``(gene, chromosome)`` entry in *gene_dict*; an isoform
    without such an entry gets a running number (restarting at 1 on every
    call) as gene and ``'-'`` as chromosome.  Genes missing from
    *gene_symbols* get ``'-'`` for symbol, start and end.

    Args:
        out: object with a ``write`` method (an open text file).
        count_dict: isoform name -> {read_file: count}.
        read_files: read file names selecting and ordering the count columns.
        gene_dict: isoform name -> ``(gene, chromosome)``.
        gene_symbols: gene -> ``(symbol, start, end)``, all strings.
    """
    connect_dict = {}
    not_matched = 0
    for isoform in sorted(count_dict):
        info = gene_dict.get(isoform)
        if not info:
            not_matched += 1
            info = (str(not_matched), '-')
        counts = [str(count_dict[isoform][read_file]) for read_file in read_files]
        connect_dict.setdefault(info, {})[isoform] = counts

    for (gene, chromosome), isoforms in connect_dict.items():
        gene_symbol, start, end = gene_symbols.get(gene, ('-', '-', '-'))
        prefix = '_'.join((gene, gene_symbol, chromosome, start, end)) + '~'
        for isoform, expression in isoforms.items():
            out.write(prefix + isoform + '\t' + '\t'.join(expression) + '\n')




def collect_chromosomes(isoform_psl):
    """Return the sorted, de-duplicated target names found in a PSL file.

    The name is read from column 14 (index 13) of every tab-separated row;
    rows too short to have that column (for example blank lines) are ignored.

    Args:
        isoform_psl: path of the PSL file.

    Returns:
        Sorted list of distinct chromosome names.
    """
    with open(isoform_psl) as handle:
        rows = (line.strip().split('\t') for line in handle)
        # Consumed inside the ``with`` so the file is still open while read.
        chromosomes = {fields[13] for fields in rows if len(fields) > 13}
    return sorted(chromosomes)

def main():
    """Count reads per isoform, chromosome by chromosome, and write the table.

    Uses the module-level settings taken from the command line
    (``isoform_file``, ``sqanti_file``, ``read_files``, ``gtf_file`` and
    ``outfile_handle``).  For each chromosome found in the isoform PSL file
    the isoforms are indexed, every read file is matched against them, and
    the resulting counts are appended to the output.  The output handle is
    closed when the run ends, whether it succeeds or fails.
    """
    try:
        chromosomes = collect_chromosomes(isoform_file)
        gene_dict = read_sqanti_classification(sqanti_file)
        # The annotation does not depend on the chromosome, so it is parsed
        # once up front instead of once per loop iteration.
        gene_symbols = read_gtf_file(gtf_file)
        for chromosome in chromosomes:
            print('processing',chromosome)
            print('reading isoforms')
            isoform_features = parse_clean_psl(isoform_file,chromosome)
            print('processing isoforms')
            pointer,reverse_pointer = generate_match(isoform_features)
            # Counts start empty for every chromosome; sorting_reads fills
            # in one column per read file.
            count_dict = {}
            for read_file in read_files:
                print('processing file ',read_file)
                count_dict = sorting_reads(pointer,reverse_pointer,read_file,count_dict,chromosome)
            print('writing output')
            write_output(outfile_handle,count_dict,read_files,gene_dict,gene_symbols)
    finally:
        # Flush and release the output file even if a step above raised.
        outfile_handle.close()


# Entry point: start the run only when this file is executed as a script.
if __name__ == '__main__':
    main()
