#!/usr/bin/env python3

import sys
from collections import Counter
import dmstools
import fphd

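"""
Map barcodes to variants from PacBio long-read sequencing data.

Usage: <script> ALIGNMENTS.paf BARCODE_MAP.tsv

Reads a PAF alignment file generated by minimap2, extracts the barcode and
the coding-region variants from each read, and writes a tab-separated table
of barcode-to-variant mappings (columns BC, var_ref, var_pos, var_alt,
read_count). Per-read and per-barcode summary statistics are printed to
standard output.
"""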

if __name__ == "__main__":
    source_file_path = sys.argv[1]
    map_file_path = sys.argv[2]

    read_attribute_counts = Counter()

    barcode_to_variant_map = dict()

    # Parse BC-variant map from PAF file
    for record in dmstools.RecordReader(source_file_path):
        try:
            bc = record.alignment_subset(2629, 2649).upper()
            if len(bc) != 20:
                raise IndexError

            if bc not in barcode_to_variant_map:
                barcode_to_variant_map[bc] = Counter()

            try:
                variants = record.call_coding_variants(427, 1324)
                if len(variants) == 1:
                    barcode_to_variant_map[bc][variants[0]] += 1
                    read_attribute_counts["single_variant"] += 1
                elif len(variants) == 0:
                    barcode_to_variant_map[bc]["wildtype"] += 1
                    read_attribute_counts["wildtype"] += 1
                else:
                    barcode_to_variant_map[bc]["multiple_variants"] += 1
                    read_attribute_counts["multiple_variants"] += 1
            except ValueError:
                read_attribute_counts["indel"] += 1
            except IndexError:
                read_attribute_counts["read_too_short_to_call_variants"] += 1
        except IndexError:
            read_attribute_counts["bad_bc"] += 1

    barcode_attribute_counts = Counter()

    with open(map_file_path, "w") as out:
        out.write("BC\tvar_ref\tvar_pos\tvar_alt\tread_count\n")
        for barcode, matching_read_classes in barcode_to_variant_map.items():
            if len(matching_read_classes) == 0:
                barcode_attribute_counts["no_interpretable_reads"] += 1
                print(f"\"{barcode}\" was only seen in reads containing indels or too short to call variants")
                continue
            top_read_class, top_read_count = matching_read_classes.most_common()[0]
            top_read_fraction = top_read_count / matching_read_classes.total()
            if top_read_fraction <= 0.75:
                barcode_attribute_counts["ambiguous"] += 1
                print(f"\"{barcode}\" cannot be unambiguously mapped: " + ", ".join(map(lambda t: f"{t[0][0]}{t[0][1]}{t[0][2]} ({t[1]})", matching_read_classes.most_common())))
            elif top_read_class == "multiple_variants":
                barcode_attribute_counts["multiple_variants"] += 1
                print(f"\"{barcode}\" maps to multiple variants")
            elif top_read_class == "wildtype":
                barcode_attribute_counts["wildtype"] += 1
                out.write(f"{barcode}\tM\t1\tM\t" + str(top_read_count) + "\n")
            else:
                barcode_attribute_counts["single_variant"] += 1
                out.write(f"{barcode}\t" + "\t".join(map(str, top_read_class)) + "\t" + str(top_read_count) + "\n")
    
    print("\nRead Attributes:")
    read_total = read_attribute_counts.total()
    for attribute, count in read_attribute_counts.most_common():
        print(f"{attribute}: {count} ({count/read_total:.1%})")
    print(f"Total: {read_total} (100.0%)")
    
    print("\nBarcode Attributes:")
    barcode_total = barcode_attribute_counts.total()
    for attribute, count in barcode_attribute_counts.most_common():
        print(f"{attribute}: {count} ({count/barcode_total:.1%})")
    print(f"Total: {barcode_total} (100.0%)")
