import sys
import os
from os import listdir
import pandas as pd

configfile: "config/config.yaml"

## Settings read from the configuration file.
ngs_data_type = config['NGS_data_type'].strip().upper()
# An empty title falls back to a placeholder.
project_title = config['project_title'].strip() or 'Untitled'

work_dir = config['workdir'].strip()
project_info_f = config['project_info'].strip()
resource_alloc = dict()

# The working directory itself is never created here; the user must provide it.
if not os.path.exists(work_dir):
    print('The working directory does not exist, please create it before running the pipeline: %s' % work_dir)
    sys.exit(-1)

raw_data_folder = config['raw_data']

# makedirs(exist_ok=True) also copes with a nested raw-data path (e.g. "data/raw")
# and with the directory appearing between a check and the creation.
os.makedirs(os.path.join(work_dir, raw_data_folder), exist_ok=True)
os.makedirs(os.path.join(work_dir, 'Preprocessing'), exist_ok=True)

## Containers filled while reading the project info table:
##   patient_sample_collection: patient -> {sample name -> input path(s)}
##   patient_to_sample_pairs:   patient -> list of [normal sample, tumor sample]
##   input_format:              'bam' or 'fastq' once the first data row is read
patient_sample_collection = dict()
patient_to_sample_pairs = dict()
input_format = None

if not os.path.exists(project_info_f):
    print('The project info table does not exist, please create it before running the pipeline: %s' % project_info_f)
    sys.exit(-1)

## Read the tab-separated project info table. The first line must be the
## header below; every following line with at least six columns describes one
## tumor/normal pair of a patient. Shorter lines are skipped silently.
expected_header = ['PatientID', 'TumorName', 'TumorPath', 'NormalName', 'NormalPath', 'DataType']

count = 0
with open(project_info_f, 'r') as handle:
    for line in handle:
        line = line.strip()
        line = line.split('\t')
        if count == 0:
            # Comparing a slice (instead of indexing each column) means a short or
            # blank header line is reported as an invalid table rather than
            # crashing with an IndexError.
            if line[:6] != expected_header:
                print('The project info table is not correct, please modify it before running the pipeline: %s' % config['project_info'].strip())
                sys.exit(-1)
            count += 1
            continue

        if len(line) < 6:
            continue
        patient = line[0].strip()
        tumor_sample = line[1].strip()
        tumor_path = line[2].strip()
        normal_sample = line[3].strip()
        normal_path = line[4].strip()
        _input_format = line[5].strip().lower()

        if _input_format not in ('bam', 'fastq'):
            print('Input data format can only be bam or fastq: %s' % config['project_info'].strip())
            sys.exit(-1)

        # The first data row fixes the format; every later row must agree with it.
        if input_format is None:
            input_format = _input_format
        elif input_format != _input_format:
            print('Only one format is allowed for the input data: %s' % config['project_info'].strip())
            sys.exit(-1)

        # A sample name keeps the first path it was listed with; repeated
        # names on later rows do not overwrite it.
        samples_of_patient = patient_sample_collection.setdefault(patient, dict())
        samples_of_patient.setdefault(tumor_sample, tumor_path)
        samples_of_patient.setdefault(normal_sample, normal_path)

        patient_to_sample_pairs.setdefault(patient, []).append([normal_sample, tumor_sample])
        count += 1

## Verify that every input file listed in the project info table exists
## before anything is linked into the working directory.

for patient, listed_samples in patient_sample_collection.items():
    for sample, listed_path in listed_samples.items():
        if input_format == 'bam':
            # BAM input: the table entry is used as a single file path.
            if not os.path.exists(listed_path):
                print('This file does not exist: %s. Please double check it and correct the project info table.' % listed_path)
                sys.exit(-1)
            continue

        # Otherwise the entry is a comma-separated list of one or two files.
        file_paths = listed_path.split(',')
        if len(file_paths) > 2:
            print('Error: %s. This pipeline only supports one (single-end) or two fastq files (pair-end) for each sample. Please double check it and correct the project info table.' % listed_path)
            sys.exit(-1)
        for file_path in file_paths:
            if not os.path.exists(file_path):
                print('This file does not exist: %s. Please double check it and correct the project info table.' % file_path)
                sys.exit(-1)

## Create symbolic links to the input data inside the working directory.

def _ensure_symlink(source, link_path):
    """Make ``link_path`` a symbolic link pointing at ``source``.

    A missing link is created. An existing symbolic link (dangling ones
    included) is replaced only when it points somewhere else. Anything at
    ``link_path`` that is not a symbolic link is left untouched.
    """
    # lexists() is True for dangling links too; exists() would report them as
    # missing and the following symlink() call would fail with FileExistsError.
    if not os.path.lexists(link_path):
        os.symlink(source, link_path)
    elif os.path.islink(link_path) and os.readlink(link_path) != source:
        os.unlink(link_path)
        os.symlink(source, link_path)


for patient in patient_sample_collection:
    raw_patient_dir = os.path.join(work_dir, raw_data_folder, patient)
    os.makedirs(raw_patient_dir, exist_ok=True)

    if input_format != 'bam':
        # FASTQ inputs are linked a second time under Preprocessing/<patient>.
        # The directory is needed for single-end and paired-end samples alike.
        preprocessing_patient_dir = os.path.join(work_dir, 'Preprocessing', patient)
        os.makedirs(preprocessing_patient_dir, exist_ok=True)

    for sample in patient_sample_collection[patient]:
        link_prefix = patient + '_' + sample

        if input_format == 'bam':
            _ensure_symlink(patient_sample_collection[patient][sample],
                            os.path.join(raw_patient_dir, link_prefix + '.bam'))
            continue

        source_file_paths = patient_sample_collection[patient][sample].split(',')
        if len(source_file_paths) == 1:
            ## single end sequencing
            link_suffixes = ['.fq.gz']
        else:
            ## paired end sequencing
            link_suffixes = ['_R1.fq.gz', '_R2.fq.gz']

        # The raw link and its Preprocessing twin are handled independently, so
        # a missing or stale twin is repaired even when the raw link is fine.
        for source_file_path, link_suffix in zip(source_file_paths, link_suffixes):
            _ensure_symlink(source_file_path,
                            os.path.join(raw_patient_dir, link_prefix + link_suffix))
            _ensure_symlink(source_file_path,
                            os.path.join(preprocessing_patient_dir, link_prefix + link_suffix))


## Pick the resource allocation section and write the matching LSF profile
## into the working directory.

# Any data type other than WGS is handled as WES.
resource_alloc_tag = 'WGS' if ngs_data_type == 'WGS' else 'WES'
# Only the section that is actually used is read from the configuration.
resource_alloc = config[resource_alloc_tag]

if resource_alloc_tag == 'WES':
    lsf_template = 'config/wes_lsf.yaml'
else:
    lsf_template = 'config/wgs_lsf.yaml'

# Context managers close both files even if reading or writing fails.
with open(lsf_template, 'r') as lsf_handle:
    lsf_content = lsf_handle.read()

# '###' is the placeholder for the project title in the template.
lsf_content = lsf_content.replace('###', project_title)

with open(os.path.join(work_dir, 'lsf.yaml'), 'w') as output_handle:
    output_handle.write(lsf_content)

## Determine all the output files requested from the workflow: per-sample preprocessing and QC results, per-patient SNV calls, and the project-wide summary files.

def get_output_files():
    """Return the list of target paths handed to ``rule all``.

    The list holds, in this order: for each patient, eight files per sample
    (processed BAM and index plus QC reports), followed by that patient's five
    SNV-calling files; then five project-wide files at the very end.
    Patients and samples are taken from ``patient_sample_collection``.
    """
    # Filled with (patient, patient, sample).
    preprocessing_templates = (
        "Preprocessing/%s/%s_%s_output-GATK.bam",
        "Preprocessing/%s/%s_%s_output-GATK.bai",
    )
    # Filled with (patient, sample).
    sample_qc_templates = (
        "QC/PreQC/%s_%s_R1_fastqc.html",
        "QC/PreQC/%s_%s_R2_fastqc.html",
        "QC/PostQC/%s_%s_R1_val_1_fastqc.html",
        "QC/PostQC/%s_%s_R2_val_2_fastqc.html",
        "QC/AddRG/%s_%s_summary_metrics.txt",
        "QC/MarkDup/%s_%s_summary_metrics.txt",
    )
    # Filled with (patient, "normal", "tumor"): the literal words, not sample names.
    patient_snv_templates = (
        "SNVCalling/MuSE2/%s_%s_%s.vcf.gz",
        "SNVCalling/Strelka2/%s_%s_%s.vcf.gz",
        "SNVCalling/RawMerge/%s_%s_%s.vcf",
        "SNVCalling/RawMergeAdjustedDP/%s_%s_%s.vcf",
        "SNVCalling/IndividialMAF/%s_%s_%s.maf",
    )
    project_wide_targets = [
        "QC/MultiQC/pre_qc.html",
        "QC/MultiQC/post_qc.html",
        "QC/MultiQC/aligned_bam_qc.html",
        "QC/MultiQC/mark_duplicates_qc.html",
        "SNVCalling/FinalMAF/final.maf",
    ]

    targets = []
    for patient, samples in patient_sample_collection.items():
        for sample in samples:
            targets.extend(tpl % (patient, patient, sample) for tpl in preprocessing_templates)
            targets.extend(tpl % (patient, sample) for tpl in sample_qc_templates)
        targets.extend(tpl % (patient, "normal", "tumor") for tpl in patient_snv_templates)

    targets.extend(project_wide_targets)
    return targets

# Snakemake directive: use the configured working directory as the working
# directory of the workflow.
workdir: config['workdir'].strip()

# patient -> list of that patient's sample names, in the order the names were
# stored in patient_sample_collection.
patient_to_samples = {
    patient_id: list(sample_paths)
    for patient_id, sample_paths in patient_sample_collection.items()
    if sample_paths
}

def get_samples_for_patient(wildcards):
    """Look up the samples recorded for the patient named by ``wildcards.patient``.

    Returns the entry of ``patient_sample_collection`` for that patient, i.e.
    a dict mapping each sample name to the path string listed for it in the
    project info table. Raises ``KeyError`` for an unknown patient.
    """
    patient_id = wildcards.patient
    return patient_sample_collection[patient_id]



# Target rule: its inputs are all the paths returned by get_output_files(),
# so requesting this rule requests every one of those files.
rule all:
    input:
        get_output_files()

# The rules themselves live in the files included below.
# NOTE(review): what each file contains is only suggested by its name here;
# confirm against the rule files.
include: "rules/index_ref.smk"
include: "rules/preprocessing.smk"
include: "rules/muse2.smk"
include: "rules/strelka2.smk"
include: "rules/merge_calls.smk"
include: "rules/adjust_read_depth_according_to_insert_size.smk"
include: "rules/vcf2maf.smk"
include: "rules/postprocessing.smk"
include: "rules/multiqc.smk"
