#!/usr/bin/env python
from __future__ import print_function
import sys
import os
import re
import argparse
from collections import defaultdict

def parseArgs():
    """Parse the command line and return the argparse namespace.

    Two positional arguments are defined: ``tblFile`` (the tbl input path)
    and ``bedFile`` (the BED output path).
    """
    parser = argparse.ArgumentParser(
        description="""Take a generated tbl and create a simple BED, one record per exon or CDS.  Uses for debugging""")
    positionals = (
        ("tblFile", "tbl input file"),
        ("bedFile", "output BED file"),
    )
    for arg_name, arg_help in positionals:
        parser.add_argument(arg_name, help=arg_help)
    return parser.parse_args()

def loadTbl(tblFh):
    """Load a feature table into per-sequence lists of tab-split rows.

    A line starting with ``>Features`` opens a new sequence; its second
    whitespace-separated token becomes the dictionary key.  Every other
    line is split on tabs and appended to the current sequence's list.
    Rows read before the first ``>Features`` line are stored under the
    key ``None``.

    tblFh: iterable of text lines (typically an open file).
    Returns a defaultdict(list) mapping sequence key -> list of rows,
    where each row is a list of column strings.
    """
    seqFeatures = defaultdict(list)
    curFeat = None
    for rawLine in tblFh:
        # Strip only the line terminator.  Unconditionally dropping the
        # last character would eat real data when the final line of the
        # file has no trailing newline.
        line = rawLine.rstrip("\r\n")
        if line.startswith(">Features"):
            curFeat = line.split()[1]
        else:
            seqFeatures[curFeat].append(line.split('\t'))
    return seqFeatures

def is_feature_start(feature):
    """Return True for a row that opens a feature.

    Such a row has a non-empty first column and exactly three columns,
    e.g. ``start<TAB>end<TAB>key``.
    """
    if feature[0] == "":
        return False
    return len(feature) == 3

def is_feature_coord(feature):
    """Return True when the row's first column is non-empty.

    This holds for feature-start rows as well as for the coordinate-only
    continuation rows that follow them.
    """
    first_column = feature[0]
    return not (first_column == "")

def strip_incmpl(coord):
    """Return coord with every '<' and '>' character removed."""
    return "".join(ch for ch in coord if ch not in "<>")

def pop_annot(features, ifeat):
    """Extract the next feature found at or after index ifeat.

    Rows are skipped until one satisfies is_feature_start().  From that
    row on, rows are consumed up to (but not including) the next
    feature-start row or the end of the list.  Rows accepted by
    is_feature_coord() become [start, end] integer lists, with the third
    column appended when the row has one; every other row is collected
    unchanged as an annotation row.

    Returns (next_index, (coords, annot)), or (len(features), None) when
    no feature-start row is found.
    """
    total = len(features)

    # Skip ahead to a feature-start row such as:  218019  218204  mRNA
    pos = ifeat
    while pos < total and not is_feature_start(features[pos]):
        pos += 1
    if pos == total:
        return pos, None

    intervals = []
    qualifiers = []
    while True:
        row = features[pos]
        pos += 1
        if not is_feature_coord(row):
            qualifiers.append(row)
        else:
            interval = [int(strip_incmpl(row[0])), int(strip_incmpl(row[1]))]
            # Keep the feature key when the row carries a third column.
            interval.extend(row[2:3])
            intervals.append(interval)
        # Stop at the end of the table or just before the next feature.
        if pos == total or is_feature_start(features[pos]):
            return pos, (intervals, qualifiers)

def process_feature(seqname, coords, annot, bedFh):
    """Write one tab-separated BED6 line per interval of a feature.

    seqname: value written in the first BED column.
    coords: non-empty list of [start, end] integer intervals; the first
        interval must carry the feature key as a third element, which is
        used as the BED name for every line.
    annot: annotation rows of the feature (currently unused).
    bedFh: writable text file object.

    An interval with start > end marks the minus strand.  Each line is
    written as (lower - 1, upper), i.e. the lower coordinate is shifted
    down by one for BED output.
    """
    # Take the strand from the first interval whose endpoints differ.  A
    # single-base interval (start == end) says nothing about orientation,
    # and judging by it alone would flag plus-strand features as '-' and
    # emit start > end for their remaining intervals.  Default to '+'.
    strand = next(('+' if coord[0] < coord[1] else '-'
                   for coord in coords if coord[0] != coord[1]),
                  '+')
    name = coords[0][2]
    for coord in coords:
        start = min(coord[0], coord[1]) - 1
        end = max(coord[0], coord[1])
        print(seqname, start, end, name, 0, strand, sep='\t', file=bedFh)

def process_features(seqname, features, bedFh):
    """Write BED lines for every feature of one sequence.

    seqname: sequence name passed through to process_feature().
    features: list of tab-split rows for this sequence.
    bedFh: writable text file object.
    """
    ifeat = 0
    while ifeat < len(features):
        ifeat, coords_annot = pop_annot(features, ifeat)
        # pop_annot returns None when no further feature exists.  The index
        # alone cannot signal that: it also equals len(features) right
        # after the last real feature, which must still be written.
        if coords_annot is None:
            break
        coords, annot = coords_annot
        process_feature(seqname, coords, annot, bedFh)

def main(opts):
    """Convert the tbl file at opts.tblFile into BED at opts.bedFile.

    The whole table is loaded first; sequences are then written in sorted
    key order.
    """
    with open(opts.tblFile) as tblIn:
        rowsBySeq = loadTbl(tblIn)

    with open(opts.bedFile, "w") as bedOut:
        for seqname in sorted(rowsBySeq):
            process_features(seqname, rowsBySeq[seqname], bedOut)
        

# Run only when executed as a script, so importing this module (e.g. to
# reuse or test the helpers) does not parse sys.argv or touch any files.
if __name__ == "__main__":
    main(parseArgs())
