#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# SK

import sys

def error(string, error_type=1):
    """Report a fatal problem on stderr and terminate the process.

    Args:
        string: Message text, written after an ``ERROR: `` prefix.
        error_type: Exit status passed to ``SystemExit`` (default 1).

    Raises:
        SystemExit: Always, carrying ``error_type`` as its code.
    """
    message = 'ERROR: {}\n'.format(string)
    sys.stderr.write(message)
    raise SystemExit(error_type)


def log(string, newline_before=False):
    """Write a progress message to stderr with a ``LOG: `` prefix.

    Args:
        string: Message text.
        newline_before: When truthy, emit an empty line before the message.
    """
    prefix = '\n' if newline_before else ''
    sys.stderr.write(prefix + 'LOG: {}\n'.format(string))

#####

# Command-line arguments: the taxonomy names file ('\t|\t'-separated,
# names.dmp-style) followed by the BLAST tabular output. Fail with a usage
# message instead of an IndexError traceback when either is missing.
if len(sys.argv) < 3:
    error(f'Usage: {sys.argv[0]} <names.dmp> <blast_output>')

taxfile = sys.argv[1]
blastfile = sys.argv[2]

#####

# read in blast output
# Only the first row seen for each query is counted; later rows for the same
# query are treated as secondary hits (presumably lower-ranked, assuming
# BLAST's usual per-query ordering). Each counted hit is binned by E-value
# into high_conf / low_conf / rest_conf, which map taxid -> number of queries.
log('Reading in blast output ...')

ctr = 0
lastquery = None

high_conf_threshold = 1e-20
high_conf = {}
low_conf_threshold = 1e-10
low_conf = {}
rest_conf = {}

queries = set()
taxids = set()

with open(blastfile) as infh:
    for line in infh:

        # status
        if ctr % 1000 == 0:
            sys.stderr.write(f'Parsed lines: {ctr}\r')
        ctr += 1

        lt = line.split()
        # Skip blank lines and '#' comment/header lines rather than crashing
        # on them (IndexError on an empty row, ValueError in float()).
        if not lt or lt[0].startswith('#'):
            continue
        # Column 11 (index 10) is the E-value, column 13 (index 12) the taxid.
        if len(lt) < 13:
            error(f'{blastfile}: line {ctr} has {len(lt)} columns, expected at '
                  'least 13 (E-value in column 11, taxid in column 13)')

        query = lt[0]
        if query == lastquery:
            # always skip secondary hits
            continue
        # Secondary-hit skipping relies on rows being grouped by query; an
        # explicit check (unlike assert) also survives `python -O`.
        if query in queries:
            error(f'{blastfile}: hits for query {query} are not contiguous; '
                  'expected output grouped by query')
        lastquery = query
        queries.add(query)

        evalue = float(lt[10])
        # handle multiple taxids on row: keep only the first one
        taxid = lt[12].split(';', 1)[0]
        taxids.add(taxid)

        if evalue <= high_conf_threshold:
            conf_bin = high_conf
        elif evalue <= low_conf_threshold:
            conf_bin = low_conf
        else:
            conf_bin = rest_conf
        conf_bin[taxid] = conf_bin.get(taxid, 0) + 1


log(f'Parsed lines: {ctr}')
log(f'Num queries: {len(queries)}')

highsum = sum(high_conf.values())
lowsum = sum(low_conf.values())
restsum = sum(rest_conf.values())

log(f'High conf: {highsum}')
log(f'Low conf:  {lowsum}')
log(f'Rest conf: {restsum}')
log(f'Total:     {highsum+lowsum+restsum}')


#####
# read in taxonomy names
# Parse the '\t|\t'-separated names file (each line terminated by '\t|') and
# keep only the 'scientific name' entry for taxids seen in the BLAST output.
# Result layout: taxnames[taxid]['scientific name'] -> name.
log('Reading in taxonomy names ...')

taxnames = {}

with open(taxfile, encoding='utf-8') as infh:
    for line in infh:

        lt = line.strip().split('\t|\t')
        # enforce correct splitting: four fields, the last one still carrying
        # the line-terminating '\t|' (explicit check, unlike assert, is not
        # stripped by `python -O`)
        if len(lt) != 4 or not lt[3].endswith('\t|'):
            error(f'Unexpected line format in {taxfile}: {line!r}')

        # skip unseen taxids
        taxid = lt[0]
        if taxid not in taxids:
            continue

        name = lt[1]
        nclass = lt[3][:-2]  # drop only the trailing '\t|' terminator
        if '\t' in nclass or '|' in nclass:
            error(f'Unexpected name class field in {taxfile}: {line!r}')

        # add only scientific name
        if nclass != 'scientific name':
            continue

        taxnames.setdefault(taxid, {})[nclass] = name


#####
# output sets
# Write the report to stdout: '#'-prefixed summary lines, then one
# tab-separated table per confidence bin, sorted by read count (descending).

print(f'# Blast output breakdown by NCBI taxids: {blastfile}')
print(f'# High confidence (<= {high_conf_threshold}):\t{highsum}')
print(f'# Low confidence  (<= {low_conf_threshold}):\t{lowsum}')
print(f'# Rest alig reads (> {low_conf_threshold}):\t{restsum}')
print(f'# Total aligned:\t{highsum+lowsum+restsum}')

conf_sets = [
    (f'High confidence - E-value <= {high_conf_threshold}', high_conf),
    (f'Low confidence - {high_conf_threshold} < E-value <= {low_conf_threshold}', low_conf),
    (f'rest confidence - E-value > {low_conf_threshold}', rest_conf),
]

missing_names = set()
for setname, conf_set in conf_sets:

    records = []
    for tax, count in conf_set.items():
        # A taxid from the BLAST output may have no scientific name in the
        # names file (possibly a placeholder or an outdated taxid). Report it
        # as 'NA' instead of aborting the whole report.
        sciname = taxnames.get(tax, {}).get('scientific name')
        if sciname is None:
            missing_names.add(tax)
            sciname = 'NA'

        # taxid, scientific name, number of reads
        records.append((tax, sciname, count))

    records.sort(key=lambda rec: rec[2], reverse=True)

    # print out
    print('#\n# ' + setname)
    print('# taxid\tscientific_name\tread_count')
    for rec in records:
        print(f'{rec[0]}\t{rec[1]}\t{rec[2]}')

if missing_names:
    shown = ', '.join(sorted(missing_names)[:10])
    log(f'No scientific name found for {len(missing_names)} taxid(s), reported as NA: {shown}')
