# Script to add `repeat_class`, `repeat_class_broad` and `color` columns (derived from the
# `repeat_family` column) to a tab-separated RepeatMasker (RepMask) table read from stdin.
# Usage: python /path/to/script.py -h

import sys
import argparse
import pandas as pd
from pathlib import Path
import numpy as np
import re

def cmdline_args():
    """Build the command-line parser and return the parsed arguments.

    No options are registered, so the returned namespace is empty; parsing
    still happens so that ``-h``/``--help`` shows the description and usage.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        usage='cat input.txt | python parse_repeatfamily_column.py [options] > output.txt',
        description="""
        Parses the `repeat_family` column to get additional columns for
        `repeat_class` and `repeat_class_broad`.

        Makes use of TE_color.csv and TE_color_broad.csv to provide the color mappings.  
        """,
    )
    parsed = parser.parse_args()
    return parsed

def read_from_stdin():
    """Load the tab-separated table on standard input into a DataFrame.

    All columns are read as ``object`` (no numeric conversion). ``header=0``
    takes the first (non-comment) line as column names, ``index_col=False``
    keeps every column as data, and text after ``#`` is ignored as a comment.
    """
    table = pd.read_table(
        sys.stdin,
        header=0,
        index_col=False,
        comment='#',
        dtype='object',
    )
    return table

# Patterns used to walk a repeat family name up its hierarchy:
# _HYPHEN_SUFFIX drops only the LAST "-<segment>" ("DNA/TcMar-Mariner-3" -> "DNA/TcMar-Mariner");
# _SLASH_SUFFIX drops everything from the first "/" ("DNA/TcMar" -> "DNA").
_HYPHEN_SUFFIX = re.compile(r"-[^-]*$")
_SLASH_SUFFIX = re.compile(r"/.*$")


def _load_color_map(filename):
    """Return a {name: color} dict read from a CSV located next to this script.

    The first two lines of the file are skipped; column 0 supplies the keys
    and column 1 the values (any further columns are ignored).
    """
    path = Path(__file__).resolve().parent / filename
    table = pd.read_csv(path, header=None, skiprows=2)
    return table.set_index(0).to_dict()[1]


def _hierarchy_levels(family):
    """Return `family` followed by successively coarser names, most specific first.

    Hyphen suffixes are removed one at a time, then everything from the first
    "/" is dropped, e.g. 'DNA/TcMar-Mariner-3' ->
    ['DNA/TcMar-Mariner-3', 'DNA/TcMar-Mariner', 'DNA/TcMar', 'DNA'].
    The last element is always the broad class (text before the first hyphen,
    then before the first slash).
    """
    levels = [family]
    while "-" in levels[-1]:
        levels.append(_HYPHEN_SUFFIX.sub("", levels[-1]))
    levels.append(_SLASH_SUFFIX.sub("", levels[-1]))
    return levels


def _resolve_repeat_class(family, known_classes):
    """Return the most specific level of `family` that is a key of `known_classes`.

    Falls back to the broadest level when no level is known.
    """
    levels = _hierarchy_levels(family)
    return next((level for level in levels if level in known_classes), levels[-1])


# parse arguments (no options are defined; this enables -h/--help)
args = cmdline_args()

# read from stdin
df = read_from_stdin()

# color mappings for the TEs (specific classes and broad classes)
color_dict = _load_color_map("TE_color.csv")
color_dict_broad = _load_color_map("TE_color_broad.csv")

# Match each repeat's name to the color dictionary, going less and less
# granular until a known class is found, e.g.
# DNA/TcMar-Mariner-3 -> DNA/TcMar-Mariner -> DNA/TcMar -> DNA.
# Each distinct family is resolved once and then mapped onto the rows.
families = df['repeat_family'].dropna().unique()
repeat_class_of = {f: _resolve_repeat_class(f, color_dict) for f in families}
broad_class_of = {f: _hierarchy_levels(f)[-1] for f in families}

# Rows with a missing repeat_family get NaN (written as '' below) rather than
# raising inside the regex substitution.
df['repeat_class'] = df['repeat_family'].map(repeat_class_of)
df['repeat_class_broad'] = df['repeat_family'].map(broad_class_of)

# define color column (broad categories)
df['color'] = df['repeat_class_broad'].map(color_dict_broad)

# print to stdout
df.to_csv(sys.stdout, sep='\t', na_rep='', index=False, header=True)
