#!/usr/bin/env python

from collections import defaultdict
import re
import gzip
from sys import argv, exit


"""
sgrna_id_seq.fa
>sgrna_1
CATCTTCTTTCACCTGAACG

sgrna_gene2seq.txt
A1BG    CATCTTCTTTCACCTGAACG
"""


# Function opening gzipped file or not
def open_file(infile, mode='r'):
    """Open *infile*, transparently decompressing it if it is gzipped.

    A file whose name ends in ``.gz`` or ``.gzip`` is opened with
    ``gzip.open``; anything else is opened with the built-in ``open``.

    Args:
        infile: path to the file (``str`` or path-like).
        mode: file mode. A plain ``'r'`` is upgraded to ``'rt'`` for gzip
            files so both branches yield ``str`` lines rather than bytes.

    Returns:
        An open file object, usable as a context manager.
    """
    # Literal suffix test: an unescaped regex such as '.gz$' would also
    # match names like 'counts_gz' and try to gunzip a plain-text file.
    if str(infile).endswith(('.gz', '.gzip')):
        if mode == 'r':
            mode = 'rt'
        return gzip.open(infile, mode=mode)
    return open(infile, mode=mode)
    

def parse_fasta(fasta):
    """Read a FASTA file into a ``{record_name: sequence}`` dict.

    The record name is the header text after '>' up to the first
    whitespace, further truncated at the first '/' (so '>sg1/1 desc'
    yields 'sg1'). All sequence lines following a header are concatenated,
    so both single-line and line-wrapped FASTA are accepted. Blank lines
    are ignored. If a name repeats, the later record wins.

    Args:
        fasta: path to a FASTA file, optionally gzip-compressed
            (opened via ``open_file``).

    Returns:
        dict mapping record name to its sequence string.

    Raises:
        ValueError: if a sequence line appears before any header.
    """
    name2seq = {}
    name = None
    chunks = []
    with open_file(fasta) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # Flush the previous record before starting a new one.
                if name is not None:
                    name2seq[name] = ''.join(chunks)
                fields = line[1:].split()
                name = fields[0].split('/')[0] if fields else ''
                chunks = []
            elif name is None:
                raise ValueError(
                    'sequence line before any FASTA header in {}'.format(fasta))
            else:
                chunks.append(line)
    # Flush the final record (it has no following header to trigger it).
    if name is not None:
        name2seq[name] = ''.join(chunks)
    return name2seq


def build_map_sg_id2gene(sgrna_id_seq_fasta, sgrna_gene2seq_file):
    """Map each sgRNA id to its gene name by joining on guide sequence.

    Args:
        sgrna_id_seq_fasta: FASTA of sgRNA id -> guide sequence.
        sgrna_gene2seq_file: tab-separated table; column 1 is the gene
            (only its first space-delimited token is kept), column 2 the
            guide sequence. Blank lines are skipped. If a sequence appears
            more than once, the last gene listed for it wins.

    Returns:
        dict mapping sgRNA id to gene name.

    Raises:
        KeyError: if a FASTA sequence has no entry in the gene table.
    """
    sg_seq2gene = {}
    with open_file(sgrna_gene2seq_file) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing empty lines
            fields = line.split('\t')
            sg_seq2gene[fields[1]] = fields[0].split(' ')[0]

    sg_id2gene = {}
    for sg_id, seq in parse_fasta(sgrna_id_seq_fasta).items():
        try:
            sg_id2gene[sg_id] = sg_seq2gene[seq]
        except KeyError:
            raise KeyError('sgRNA {!r} (sequence {}) not found in {}'.format(
                sg_id, seq, sgrna_gene2seq_file)) from None
    return sg_id2gene


def parse_idxstats(idxstats_file):
    """Read per-reference read counts from an idxstats table.

    Each non-blank line is tab-separated. Column 1 is the reference
    (sgRNA) name and column 3 is parsed as an int count. This assumes the
    ``samtools idxstats`` layout (name, length, mapped, unmapped), where
    column 3 is mapped reads. Lines starting with '*' are skipped.

    Args:
        idxstats_file: path to the table, optionally gzip-compressed.

    Returns:
        dict mapping reference name to its int count. Unknown names raise
        ``KeyError`` on lookup rather than silently yielding a default.
    """
    sg_id2count = {}
    with open_file(idxstats_file) as fh:
        for line in fh:
            # '*' is presumably the samtools unplaced-reads line, not an sgRNA.
            if line.startswith('*'):
                continue
            line = line.strip()
            if not line:
                continue
            fields = line.split('\t')
            sg_id2count[fields[0]] = int(fields[2])
    return sg_id2count


def main():
    """Write a per-sgRNA table of ``<sgrna_id>\\t<gene>\\t<count>`` lines.

    Command line:
        <sgrna_id_seq.fa> <sgrna_gene2seq.txt> <idxstats> <output>

    One line is written per sgRNA present in the idxstats file. Missing
    arguments print a usage message to stderr and exit with status 1. A
    sgRNA id absent from the FASTA raises ``KeyError`` before the output
    file is touched.
    """
    if len(argv) < 5:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        exit('usage: {} <sgrna_id_seq.fa> <sgrna_gene2seq.txt> '
             '<idxstats> <output>'.format(argv[0] if argv else 'script'))
    sgrna_id_seq_fasta, sgrna_gene2seq_file, idxstats_file, output_file = argv[1:5]

    sg_id2gene = build_map_sg_id2gene(sgrna_id_seq_fasta, sgrna_gene2seq_file)
    sg_id2count = parse_idxstats(idxstats_file)

    # Build every row first so a lookup failure cannot leave a partial file.
    rows = ['\t'.join([sg_id, sg_id2gene[sg_id], str(count)]) + '\n'
            for sg_id, count in sg_id2count.items()]

    with open(output_file, 'w') as output:
        output.writelines(rows)
    
    
if "__main__" == __name__:
    # Entry point: run only when executed as a script, not when imported.
    main()