#!/usr/bin/env python
#
# Copyright (c) 2016 10x Genomics, Inc. All rights reserved.
#
import os
import docopt
import vcf
import tenkit.bio_io as tk_io
import tenkit.tabix as tk_tabix
import tenkit.reference
import sys
import collections

# Usage text for the command-line interface. This string is passed to
# docopt.docopt() below, which derives the accepted options and arguments
# from the "Usage:" section, so editing the text changes argument parsing.
__doc__ = '''
Standardize user-supplied VCF files for use as Long Ranger-compatible
precalled variants files. 

Steps include:
  1. Rename chromosomes/contigs to match the specified reference.
  2. Extract only the single specified sample. 
  3. Remove homozygous reference variants in the specified sample.
  4. Remove variants on contigs not in the specified reference.
  5. Sort the VCF file with standard tabix ordering.

The commands below should be preceded by 'longranger':

Usage:
    mkvcf --reference=PATH --sample=NAME [--keepquotes] <input_vcf>...
    mkvcf -h | --help | --version

Arguments:
    reference  Path to 10x-compatible reference folder.
    sample     Name of sample to extract from VCF input files.
    keepquotes Do not strip double quotes from each record's INFO field.
    input_vcf  One or more VCF files to be non-destructively standardized.

Options:
    -h --help    Show this message.
    --version    Show version.
'''

def error(msg):
    '''
    Report a fatal error and terminate the script.

    The message is written to standard error (followed by a newline) so
    that it is not mixed into the progress output printed on standard
    output, and the process then exits with status 1.

    msg -- the text to report; non-string values are formatted with %s.
    '''
    sys.stderr.write("%s\n" % (msg,))
    sys.exit(1)

def fixpath(path):
    '''
    Return path as an absolute, normalized path.

    A leading ~ is expanded first, then environment variables, and the
    result is finally made absolute relative to the current directory.
    '''
    expanded = os.path.expanduser(path)
    expanded = os.path.expandvars(expanded)
    return os.path.abspath(expanded)

try:
    _STRING_TYPES = basestring  # Python 2: matches both str and unicode
except NameError:
    _STRING_TYPES = str  # interpreters without basestring

def replace_double_quotes(struct, replace_str=""):
    '''
    Purge double quotes from any keys, values or items
    in the structure, by replacing with the specified
    string (default blank)

    struct      -- a string, a mapping, any other sized iterable, or an
                   arbitrary object; containers are processed recursively.
    replace_str -- text substituted for every double quote. It is applied
                   at every nesting level, not just to top-level strings.

    Returns a new object: strings come back as strings, mappings as a
    plain dict, and any other object with a length as a list (so tuples
    and sets are returned as lists). Anything else is returned unchanged.
    '''
    if isinstance(struct, _STRING_TYPES):
        return struct.replace('"', replace_str)

    # No length: a literal (number, None, ...) or unknown object.
    if not hasattr(struct, '__len__'):
        return struct

    # dicty: clean both the keys and the values
    if hasattr(struct, 'items'):
        return {
            replace_double_quotes(k, replace_str): replace_double_quotes(v, replace_str)
            for k, v in struct.items()
        }

    # sequence-y
    return [replace_double_quotes(item, replace_str) for item in struct]

# Parse the command line; docopt derives the accepted options from __doc__.
args = docopt.docopt(__doc__, version=" ") # do not remove space
sample_name = args["--sample"]
strip_double_quotes = not args['--keepquotes']
ref_path = fixpath(args["--reference"])
inputs = [fixpath(vcf_path) for vcf_path in args["<input_vcf>"]]

# The reference must be an existing folder; error() exits on failure.
if not os.path.exists(ref_path):
    error("Reference path does not exist: %s" % ref_path)
elif not os.path.isdir(ref_path):
    error("Please provide a reference folder, not a file: %s" % ref_path)

# Contig names of the reference, used below to rename and filter contigs.
fasta = tenkit.reference.open_reference(ref_path)
fasta_contigs = fasta.keys()

# Unsorted intermediate output, written in the current working directory.
out_fn = "precalled_" + sample_name + ".vcf"
out_file = open(out_fn, 'w')

# The header of the first input VCF is the template for the output header.
template_reader = vcf.Reader(open(inputs[0]))

# When the reference uses chr-style contig names, rewrite the template's
# contig header entries to chr-style and keep only the contigs that the
# reference actually contains.
if fasta_contigs[0].startswith("chr"):
    renamed_contigs = collections.OrderedDict()
    for name, contig in template_reader.contigs.items():
        if not name.startswith("chr"):
            name = "chr" + name
            contig = vcf.parser._Contig("chr" + contig.id, contig.length)

        if name in fasta_contigs:
            renamed_contigs[name] = contig

    template_reader.contigs = renamed_contigs

# The output header lists only the extracted sample.
template_reader.samples = [sample_name]

# Record this script (and the TENX_VERSION environment variable, if set)
# as a "source" entry in the output header.
sources = template_reader.metadata.setdefault("source", [])
sources.append("10X/longranger/bin/mkvcf %s" % os.getenv('TENX_VERSION', ''))
vfw = tk_io.VariantFileWriter(out_file, template_reader=template_reader)

# Number of records written to the output so far (not records read).
num_vars = 0

# Loop invariants, computed once instead of once per record: whether the
# reference uses chr-style contig names, and a set of its contig names so
# the per-record membership test does not scan a list.
ref_is_chr_style = fasta_contigs[0].startswith("chr")
ref_contig_names = frozenset(fasta_contigs)

try:
    for vcf_fn in inputs:
        print("Extracting sample %s from %s..." % (sample_name, vcf_fn))
        vfr = tk_io.VariantFileReader(vcf_fn)
        for record in vfr.record_getter():
            # If reference has chr-style contigs, fix variant to chr-style
            if ref_is_chr_style and not record.CHROM.startswith('chr'):
                record.CHROM = 'chr' + record.CHROM

            # Drop variants on contigs not in this reference.
            if record.CHROM not in ref_contig_names:
                continue

            if strip_double_quotes:
                # purge double quotes to not trip up VCFALLELICPRIMITIVES
                # Iterate over a snapshot of the keys because quoted keys
                # are renamed (popped and re-inserted) inside the loop.
                for key in list(record.INFO.keys()):
                    record.INFO[key] = replace_double_quotes(record.INFO[key])
                    if '"' in key:
                        record.INFO[key.replace('"', '')] = record.INFO.pop(key)

            # Locate the requested sample's call; stop at the first match.
            sample = None
            for s in record.samples:
                if s.sample == sample_name:
                    sample = s
                    break

            if sample is None:
                error("Sample not found: %s" % sample_name)

            record.samples = [sample]

            # Skip calls without a genotype or with any missing allele.
            gt = getattr(sample.data, 'GT', None)
            if gt is None or '.' in gt:
                continue

            # Accept phased ('|') and unphased ('/') separators, including
            # a mix of both within one genotype.
            genotype = [int(allele) for allele in gt.replace('|', '/').split('/')]

            # Remove reference-only calls of any ploidy (0/0, 0|0, and also
            # haploid 0), since none of them carries an alternate allele.
            if not any(genotype):
                continue

            vfw.write_record(record)
            num_vars += 1
            if num_vars % 100000 == 0:
                print("Processed %d records" % num_vars)
finally:
    # Close (and flush) the unsorted output even when a record fails or
    # the script exits early.
    out_file.close()

# Write the sorted result to "sorted_<out_fn>", then delete the unsorted
# intermediate file.
tk_tabix.sort_vcf(out_fn, "sorted_" + out_fn)
os.remove(out_fn)
