# Script to parse RepeatMasker simple output into BEDready format
# Usage: python /path/to/script.py -h

import sys
import argparse
import pandas as pd
from pathlib import Path
import numpy as np
import re

def cmdline_args():
    """Build the command-line parser and parse the current arguments.

    The parser defines no options of its own, so the only flag it understands
    is argparse's built-in ``-h``/``--help``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    # RawDescriptionHelpFormatter prints this text verbatim in the help output.
    description_text = """
        Format a simplified RepMask output to create BEDready format (parsable to 6 and 12).
        Combines split fragments into single unit.
        """
    usage_text = 'cat input.txt | python parse_simple_to_bedready.py [options] > output.txt'

    parser = argparse.ArgumentParser(
        description=description_text,
        usage=usage_text,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    return parser.parse_args()

def read_simple_from_stdin():
    """Read the tab-separated table arriving on standard input.

    The ``chr`` and ``id`` columns are forced to strings so that their values
    are kept as text; all other column types are inferred by pandas.

    Returns:
        A ``pandas.DataFrame`` holding the parsed table.
    """
    text_columns = {'chr': str, 'id': str}
    return pd.read_table(
        sys.stdin,
        index_col=False,
        dtype=text_columns,
    )

# parse arguments; this runs before stdin is touched, so `-h` or an
# unrecognised option makes argparse exit without waiting for piped input
args = cmdline_args()

# read repeatmasker file from stdin into the table that the sections below
# filter on its `split` column
df_raw = read_simple_from_stdin()

################################################################################

# single unit transposons: rows not flagged as split fragments; each one is
# described as a record made of exactly one block spanning start..end
df_single = df_raw.query("split != 't'").assign(
    thick_start=lambda d: d['start'],            # col7
    thick_end=lambda d: d['end'],                # col8
    block_count=1,                               # col10
    len=lambda d: d['end'] - d['start'],
    block_sizes=lambda d: d['len'].astype(str),  # col11
    block_starts='0',                            # col12
)

################################################################################

# split transposons: fragments flagged split == 't' that share an `id` are
# collapsed into a single multi-block record per id
df_split = df_raw.query("split == 't'").copy()

# Rows without an id cannot be assigned to a transposon (groupby discards
# them), so drop them explicitly. Then order the fragments by start within
# each id: the comma-joined block lists below are built in row order, and
# BED12 expects blocks in ascending start order. The stable sort keeps the
# file order for fragments that share a start.
df_split = df_split.dropna(subset=['id']).sort_values(
    by=['id', 'start'], kind='stable'
)

df_split['thick_start'] = df_split['start']  # col7
df_split['thick_end'] = df_split['end']      # col8

df_split['len'] = df_split['end'] - df_split['start']
df_split['len_str'] = df_split['len'].astype(str)

# numerator of the length-weighted arithmetic mean of perc_div (col5)
df_split['perc_div_mult_len'] = df_split['perc_div'] * df_split['len']

if df_split.empty:
    # No split fragments in the input: produce an empty table that still has
    # the per-transposon columns selected further down, rather than failing
    # on a positional lookup of a first row that does not exist.
    df_split_uniq = df_split.assign(
        block_count=pd.Series(dtype='int64'),
        block_sizes=pd.Series(dtype=object),
        block_starts=pd.Series(dtype=object),
    )
else:
    # start of every fragment relative to the smallest start of its id (col12)
    start_min_per_row = df_split.groupby('id')['start'].transform('min')
    df_split['start_rel'] = df_split['start'] - start_min_per_row
    df_split['start_rel_str'] = df_split['start_rel'].astype(str)

    # one row per id with every aggregated quantity
    per_id = df_split.groupby('id', as_index=False).agg(
        start_min=('start', 'min'),                 # overall start (col2)
        end_max=('end', 'max'),                     # overall end (col3)
        block_count=('id', 'count'),                # num of exons/units (col10)
        block_sizes=('len_str', ','.join),          # sizes of the units (col11)
        block_starts=('start_rel_str', ','.join),   # relative starts (col12)
        len_sum=('len', 'sum'),
        perc_div_agg=('perc_div_mult_len', 'sum'),
    )

    # length-weighted mean of perc_div, rounded to 1dp
    per_id['perc_div'] = np.round(
        per_id['perc_div_agg'] / per_id['len_sum'],
        decimals=1
    )

    # Keep only the first unit of each transposon as the carrier of the
    # remaining descriptive columns; its own perc_div is replaced by the
    # weighted mean computed above.
    # NOTE(review): thick_start, thick_end and len therefore stay those of
    # the first fragment only, not of the whole transposon - confirm that
    # this is intended.
    first_rows = df_split.drop_duplicates(
        subset='id', keep='first'
    ).drop(columns='perc_div')

    df_split_uniq = per_id.merge(first_rows, on='id', validate='one_to_one')

    # modify the start and end to be that of the overall
    df_split_uniq['start'] = df_split_uniq['start_min']
    df_split_uniq['end'] = df_split_uniq['end_max']

################################################################################

# columns of the final table: the twelve BED-style fields come first ...
bed9n12cols = [
    'chr', 'start', 'end', 'repeat_family', 'perc_div', 'complement',
    'thick_start', 'thick_end', 'color',
    'block_count', 'block_sizes', 'block_starts',
]
# ... followed by the extra annotation columns
cols_to_keep = [
    *bed9n12cols,
    'id', 'len', 'repmask_id', 'repeat_class', 'repeat_class_broad',
]

# stack the single-unit and merged-split tables restricted to those columns,
# then order the rows by chr and start
df = pd.concat(
    [frame[cols_to_keep] for frame in (df_single, df_split_uniq)]
).sort_values(by=['chr', 'start'])

# write to standard output as tab-separated text, without the index
df.to_csv(sys.stdout, sep='\t', na_rep='', index=False)
