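/*
 * 1. Verify the validity of the parameters and read the input VCF files.
 * 2. Build the PBWT of the reference panel (the query haplotypes are appended
 *    to it as its last rows, so the PBWT covers both).
 * 3. Run MPSC benchmark 1 on the query panel:
 *      find the fragile (required) regions
 *      build the graph of set maximal matches
 *      find a length maximal MPSC
 *      count the size of the set-maximal-match-only MPSC solution space
 */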

#include<algorithm>
#include<chrono>
#include<cstdlib>
#include<fstream>
#include<iomanip>
#include<iostream>
#include<sstream>
#include<string>
#include<vector>

//-----------------------------------------------------------------------------
//source for time: https://stackoverflow.com/questions/17432502/how-can-i-measure-cpu-time-and-wall-clock-time-on-both-linux-windows
//  Windows
#ifdef _WIN32
#include <Windows.h>
// Seconds elapsed on the high-resolution performance counter.
// Returns 0 if either the counter frequency or the counter value cannot be read.
double get_wall_time(){
    LARGE_INTEGER time,freq;
    // QueryPerformanceCounter is only consulted once the frequency is known.
    if (!QueryPerformanceFrequency(&freq) || !QueryPerformanceCounter(&time))
        return 0;
    return static_cast<double>(time.QuadPart) / freq.QuadPart;
}
// User-mode CPU time consumed by this process, in seconds (kernel time is
// not included). Returns 0 if GetProcessTimes fails.
double get_cpu_time(){
    FILETIME a,b,c,d;
    if (GetProcessTimes(GetCurrentProcess(),&a,&b,&c,&d) == 0)
        return 0;
    // d is the user-time FILETIME: a 64-bit count of 100 ns ticks stored as
    // two 32-bit halves, so join them before scaling to seconds.
    const unsigned long long ticks =
        d.dwLowDateTime | ((unsigned long long)d.dwHighDateTime << 32);
    return static_cast<double>(ticks) * 0.0000001;
}

//  Posix/Linux
#else
#include <time.h>
#include <sys/time.h>
// Elapsed-time source for the benchmark timers, in seconds.
// std::chrono::steady_clock is monotonic: unlike gettimeofday(), it cannot
// jump when the system clock is adjusted (NTP, manual changes). A jump would
// otherwise make the differences taken by the callers wrong or even negative.
// Its epoch is unspecified, so only the difference between two calls is meaningful.
double get_wall_time(){
    typedef std::chrono::duration<double> seconds;
    return seconds(std::chrono::steady_clock::now().time_since_epoch()).count();
}
// Processor time used by this process so far, converted from clock() ticks to seconds.
double get_cpu_time(){
    const clock_t ticks = clock();
    return static_cast<double>(ticks) / CLOCKS_PER_SEC;
}
#endif
//-----------------------------------------------------------------------------

// Reads the reference panel VCF (inFile) and the query VCF (qinFile) into panel.
// On return panel has M rows (M counts reference plus query haplotypes): the
// reference haplotypes come first, then the qM query haplotypes, each holding N site values.
void ReadVCF(std::string inFile, std::string qinFile, int & M, 
        int & qM, int & N, bool ** & panel);
// Builds the PBWT of panel. prefix and divergence are M x (N+1) arrays; u and v are M x N.
void PBWT(bool **panel, int **prefix, int **divergence, 
        int **u, int **v, int M, int N);
// Writes MPSC statistics to out, one row for each of the last qM haplotypes of panel.
void benchmark(bool **panel, int **prefix, int **divergence, 
        int **u, int **v, int M, int qM, int N, std::ofstream & out);

// Exits the program with an error message if `file` could not be opened.
static void requireOpen(const std::ifstream & file, const std::string & name){
    if (!file.is_open()){
        std::cout << name << " failed to open!\n";
        std::exit(1);
    }
}

// Consumes the VCF meta-information lines ("##...") of `in`. Stores in `header`
// the first line whose second character is not '#'; in a well-formed file that
// is the "#CHROM ..." column header. Exits if the stream ends before such a line.
static void readHeaderOrExit(std::istream & in, std::string & header,
        const std::string & name){
    while (std::getline(in, header)){
        // size() < 2 also covers empty lines, where header[1] would be out of range
        if (header.size() < 2 || header[1] != '#')
            return;
    }
    std::cout << name << " has no VCF header line!\n";
    std::exit(1);
}

// Number of haplotypes described by a VCF column-header line: two per sample
// column after the 9 fixed columns (CHROM POS ID REF ALT QUAL FILTER INFO FORMAT).
// Counting extracted tokens means trailing whitespace or '\r' adds no phantom sample.
static int countHaplotypes(const std::string & header){
    std::istringstream fields(header);
    std::string token;
    int column = 0, haplotypes = 0;
    while (fields >> token){
        if (++column > 9)
            haplotypes += 2;
    }
    return haplotypes;
}

// Reads the next non-blank line of `in` into `line`; returns false at end of stream.
// The site count and the parse loop both use this, so stray blank lines
// (e.g. at the end of a file) are neither counted nor parsed as sites.
static bool nextDataLine(std::istream & in, std::string & line){
    while (std::getline(in, line)){
        if (line.find_first_not_of(" \t\r") != std::string::npos)
            return true;
    }
    return false;
}

// Parses the diploid genotype fields of samples [firstSample, lastSample) from
// `fields`. Each field is read as allele, separator char, allele (e.g. "0|1").
// The two alleles go to panel rows 2*s and 2*s+1 at column `site`; a nonzero allele becomes true.
static void readGenotypes(std::istream & fields, bool ** panel,
        int firstSample, int lastSample, int site){
    int allele = 0;
    char separator = 0;
    for (int s = firstSample; s < lastSample; ++s){
        fields >> allele >> separator;
        panel[2*s][site] = (allele != 0);
        fields >> allele;
        panel[2*s + 1][site] = (allele != 0);
    }
}

// Reads the reference panel VCF `inFile` and the query VCF `qinFile`.
// On return:
//   M     - total haplotypes (reference + query)
//   qM    - query haplotypes
//   N     - sites (both files must have the same number, else the program exits)
//   panel - M x N matrix; rows 0..M-qM-1 are reference haplotypes, rows M-qM..M-1
//           are query haplotypes. Ownership passes to the caller.
// Exits the program if a file cannot be opened or lacks a header line.
void ReadVCF(std::string inFile, std::string qinFile, int & M,
        int & qM, int & N, bool ** & panel){
    std::string line;
    {//count M (reference haplotypes for now) and N
        std::ifstream in(inFile);
        requireOpen(in, inFile);
        readHeaderOrExit(in, line, inFile);
        M = countHaplotypes(line);
        N = 0;
        while (nextDataLine(in, line))
            ++N;
    }
    {//count qM, finish M
        std::ifstream qin(qinFile);
        requireOpen(qin, qinFile);
        readHeaderOrExit(qin, line, qinFile);
        qM = countHaplotypes(line);
        M += qM;
        int qN = 0;
        while (nextDataLine(qin, line))
            ++qN;
        if (qN != N){
            std::cout << "Query file and input file have different numbers of sites. Query has " << qN << ". Panel has " << N << std::endl;
            std::exit(1);
        }
    }

    // One contiguous M x N buffer with one row pointer per haplotype.
    bool *temp = new bool[static_cast<long long>(M) * N];
    panel = new bool*[M];
    for (int i = 0; i<M; i++){
        // 64-bit offset: i * N in int arithmetic overflows once M * N exceeds INT_MAX
        panel[i] = temp + static_cast<long long>(i) * N;
    }

    std::ifstream in(inFile);
    std::ifstream qin(qinFile);
    requireOpen(in, inFile);
    requireOpen(qin, qinFile);
    std::string qline, token;
    readHeaderOrExit(in, line, inFile);
    readHeaderOrExit(qin, qline, qinFile);

    std::stringstream linestr, qlinestr;
    const int refSamples = (M - qM) / 2;
    for (int j = 0; j < N; ++j){
        // The counting pass guarantees N non-blank data lines in each file.
        nextDataLine(in, line);
        nextDataLine(qin, qline);
        linestr.str(line);
        linestr.clear();
        qlinestr.str(qline);
        qlinestr.clear();
        for (int i = 0; i < 9; ++i){ // skip the fixed CHROM..FORMAT columns
            linestr >> token;
            qlinestr >> token;
        }
        readGenotypes(linestr, panel, 0, refSamples, j);
        readGenotypes(qlinestr, panel, refSamples, M / 2, j);
    }
}

// Builds the positional BWT of `panel` (M haplotypes x N sites) with the
// prefix/divergence update of Durbin's PBWT.
//   prefix[i][k]     - haplotype at position i of column k. Column 0 is the
//                      identity; column k+1 is derived from column k by a stable
//                      partition on the allele at site k.
//   divergence[i][k] - divergence value of position i in column k. Column 0 is all 0.
//   u[i][k], v[i][k] - new position in column k+1 of the haplotype at position i
//                      of column k, if its allele at site k is 0 (u) or 1 (v).
// prefix and divergence must have N+1 columns; u and v must have N columns.
void PBWT(bool **panel, int **prefix, int **divergence,
        int **u, int **v, int M, int N){

    for (int i = 0; i<M; ++i){
        prefix[i][0] = i;
        divergence[i][0] = 0;
    }

    // Partition buffers for the current site. They are hoisted out of the site
    // loop so their storage is reused instead of reallocated N times.
    std::vector<int> zeroHaps, oneHaps, zeroDivs, oneDivs;
    zeroHaps.reserve(static_cast<std::size_t>(M));
    oneHaps.reserve(static_cast<std::size_t>(M));
    zeroDivs.reserve(static_cast<std::size_t>(M));
    oneDivs.reserve(static_cast<std::size_t>(M));

    for (int k = 0; k<N; ++k){
        zeroHaps.clear();
        oneHaps.clear();
        zeroDivs.clear();
        oneDivs.clear();
        int u2 = 0, v2 = 0, p = k+1, q = k+1;
        for (int i = 0; i<M; ++i){
            u[i][k] = u2;
            v[i][k] = v2;
            const int hap = prefix[i][k];
            p = std::max(p, divergence[i][k]);
            q = std::max(q, divergence[i][k]);
            if (!panel[hap][k]){
                zeroHaps.push_back(hap);
                zeroDivs.push_back(p);
                ++u2;
                p = 0;
            }
            else{
                oneHaps.push_back(hap);
                oneDivs.push_back(q);
                ++v2;
                q = 0;
            }
        }
        // All 0-allele haplotypes precede the 1-allele ones in column k+1.
        const int numZeros = static_cast<int>(zeroHaps.size());
        for (int i = 0; i<M; ++i){
            v[i][k] += numZeros;
            if (i < numZeros){
                prefix[i][k+1] = zeroHaps[i];
                divergence[i][k+1] = zeroDivs[i];
            }
            else{
                prefix[i][k+1] = oneHaps[i - numZeros];
                divergence[i][k+1] = oneDivs[i - numZeros];
            }
        }
    }
}

// Runs MPSC benchmark 1 for every query haplotype (rows M-qM .. M-1 of panel).
// Writes a header line and then one tab-separated row per query to `out`:
//   Walltime, CPUtime   - seconds spent on the per-query computation below
//   Size                - number of positional substrings in an MPSC
//   SolSpaceSize        - number of MPSCs built only from set maximal matches (SMMs)
//   LengthMaxMPSCLength - summed length (end-start+1) of the SMMs of a length maximal MPSC
//   LengthMaxMPSC       - that MPSC as a comma-separated list of [start,end] site ranges
// A row of '-' is written when no MPSC of the query exists (or N == 0).
// prefix/divergence (M x (N+1)) and u/v (M x N) are expected to come from PBWT().
// NOTE(review): the PBWT passed in by main() contains the query rows themselves,
// so a query's neighbours may be other query haplotypes rather than reference
// haplotypes. Confirm this is intended.
void benchmark(bool **panel, int **prefix, int **divergence, 
        int **u, int **v, int M, int qM, int N, std::ofstream & out){

    out << "Walltime\tCPUtime\tSize\tSolSpaceSize"
        << "\tLengthMaxMPSCLength\tLengthMaxMPSC\n";

    for (int i = M-qM; i<M; ++i){
        double oldtime = get_wall_time();
        double oldcputime = get_cpu_time();

        const int Oid = i;

        //get PPA positions: t[j] is the position of haplotype Oid in prefix column j
        //(std::vector everywhere below, so nothing leaks per query or on `continue`)
        std::vector<int> t(N+1);
        t[0] = Oid;
        for (int j = 0; j<N; ++j)
            t[j+1] = (panel[Oid][j]) ? v[t[j]][j] : u[t[j]][j];

        //get divergence values and set maximal matches
        //mind[j] is the start of the longest match ending at site j-1, j if none
        std::vector<int> mind(N+2);
        mind[N+1] = N;
        std::vector<int> setMaxStart, setMaxEnd;
        for (int j = N; j>=0; --j){
            const int bd = (t[j] != M-1) ? divergence[t[j]+1][j] : j;
            mind[j] = std::min(divergence[t[j]][j], bd);
            if (mind[j] < mind[j+1]){
                //record set maximal match found
                setMaxStart.push_back(mind[j]);
                setMaxEnd.push_back(j-1);
            }
        }

        //get leftmost mpsc
        //l stores the starting point of the |C|-i-1 th positional substring of the leftmost mpsc
        std::vector<int> l;
        int k = mind[N], oldk = N;
        while (k != oldk){
            l.push_back(k);
            oldk = k;
            k = mind[k];
        }
        if (k != 0 || N == 0) {
            //skip this haplotype because no mpsc of it by M exists
            //(with no sites the arrays below would be empty and must not be indexed)
            out << "-\t-\t-\t-\t-\t-\n";
            continue;
        }

        //get rightmost MPSC
        //b is the opposite of the divergence array:
        //b[s] is the end of the longest match starting at site s, s-1 if none
        std::vector<int> b(N);
        int d = N; //first site whose b value is known so far
        k = N;     //current site
        while (d != 0){
            while (mind[k] < d)
                b[--d] = k-1;
            --k;
        }

        //ending point of i th positional substring in rightmost MPSC
        std::vector<int> r;
        k = 0; //first site not covered so far
        while (k != N){
            r.push_back(b[k]);
            k = b[k]+1;
        }

        const int mpscSize = static_cast<int>(r.size());
        const int numSMMs = static_cast<int>(setMaxStart.size());
        //setMaxStart/setMaxEnd were filled right to left, so SMM number s
        //(0 = leftmost) is stored at index numSMMs-s-1
        auto smmStart = [&](int s){ return setMaxStart[numSMMs - s - 1]; };
        auto smmEnd = [&](int s){ return setMaxEnd[numSMMs - s - 1]; };
        //graph edge x -> y: SMM x ends no earlier than the site just before SMM y starts
        auto hasEdge = [&](int x, int y){ return smmEnd(x) >= smmStart(y) - 1; };

        //SMMContains[c] lists the SMMs (ascending by first site) that fully contain
        //the c-th required region. Regions 0 and |C|-1 are only contained by the
        //first and last SMM, because those are the only SMMs containing sites 0 and N-1.
        std::vector<std::vector<int>> SMMContains(mpscSize);
        SMMContains[0].push_back(0);
        if (numSMMs != 1){
            SMMContains[mpscSize-1].push_back(numSMMs-1);
            //middle regions 1..|C|-2 only exist when |C| > 2; the guard also keeps
            //l[mpscSize-currentRegion-2] from reading l[-1] when |C| == 2
            if (mpscSize > 2){
                int currentRegion = 1;
                //first / last site of the current required region
                int currRegStart = r[currentRegion-1]+1;
                int currRegEnd = l[mpscSize-currentRegion-2]-1;
                for (int j = numSMMs-2; j>0; --j){
                    //j is the storage index of SMM number numSMMs-j-1
                    while(currRegStart < setMaxStart[j]){
                        ++currentRegion;
                        if (currentRegion > mpscSize-2)
                            break;
                        currRegStart = r[currentRegion-1]+1;
                        currRegEnd = l[mpscSize-currentRegion-2]-1;
                    }
                    if (currentRegion > mpscSize-2)
                        break;
                    if (setMaxEnd[j] >= currRegEnd)
                        SMMContains[currentRegion].push_back(numSMMs-j-1);
                }
            }
        }

        //mpscLength[c][x]: length of the longest path from node SMMContains[c][x] to the sink
        std::vector<std::vector<int>> mpscLength(mpscSize);
        for (int j = 0; j<mpscSize; ++j)
            mpscLength[j].resize(SMMContains[j].size());
        //length of longest path of last SMM is the length of the last SMM
        mpscLength[mpscSize-1][0] = setMaxEnd[0] - setMaxStart[0] + 1;
        for (int j = mpscSize-2; j>=0; --j){
            const std::vector<int> & cur = SMMContains[j];
            const std::vector<int> & next = SMMContains[j+1];
            //prevMax is the max path length over next[0..prevMaxIndex]; since SMMs
            //are sorted, the successors of cur[x] form a growing prefix of `next`
            int prevMax = -1, prevMaxIndex = -1;
            for (int x = 0; x < static_cast<int>(cur.size()); ++x){
                while (prevMaxIndex+1 < static_cast<int>(next.size()) &&
                        hasEdge(cur[x], next[prevMaxIndex+1])){
                    ++prevMaxIndex;
                    prevMax = std::max(prevMax, mpscLength[j+1][prevMaxIndex]);
                }
                mpscLength[j][x] = prevMax + smmEnd(cur[x]) - smmStart(cur[x]) + 1;
            }
        }
        //mpscLength[0][0] now contains the length of the length maximal MPSC

        //backtracking: follow one longest path from region 0. Each next SMM must be
        //a successor of the previously chosen one, and among those successors the
        //first with the largest mpscLength is taken, so the SMMs listed form a
        //connected path of total length mpscLength[0][0].
        std::vector<int> lengthMAXMPSC(mpscSize);
        int chosen = 0;
        for (int x = 1; x < static_cast<int>(SMMContains[0].size()); ++x)
            if (mpscLength[0][x] > mpscLength[0][chosen])
                chosen = x;
        lengthMAXMPSC[0] = SMMContains[0][chosen];
        for (int j = 1; j < mpscSize; ++j){
            int best = -1;
            for (int y = 0; y < static_cast<int>(SMMContains[j].size()) &&
                    hasEdge(lengthMAXMPSC[j-1], SMMContains[j][y]); ++y)
                if (best < 0 || mpscLength[j][y] > mpscLength[j][best])
                    best = y;
            lengthMAXMPSC[j] = SMMContains[j][best < 0 ? 0 : best];
        }

        //calculate solution space size (set maximal match only MPSCs)
        //numPaths[c][x]: number of paths from node SMMContains[c][x] to the sink
        std::vector<std::vector<long double>> numPaths(mpscSize);
        for (int j = 0; j<mpscSize; ++j)
            numPaths[j].resize(SMMContains[j].size());
        numPaths[mpscSize-1][0] = 1;
        for (int j = mpscSize-2; j>=0; --j){
            const std::vector<int> & cur = SMMContains[j];
            const std::vector<int> & next = SMMContains[j+1];
            //prevNumPaths sums numPaths over next[0..prevMaxIndex], the successors seen so far
            long double prevNumPaths = 0;
            int prevMaxIndex = -1;
            for (int x = 0; x < static_cast<int>(cur.size()); ++x){
                while (prevMaxIndex+1 < static_cast<int>(next.size()) &&
                        hasEdge(cur[x], next[prevMaxIndex+1]))
                    prevNumPaths += numPaths[j+1][++prevMaxIndex];
                numPaths[j][x] = prevNumPaths;
            }
        }
        //numPaths[0][0] now contains the size of the set maximal match only solution space

        oldcputime = get_cpu_time() - oldcputime;
        oldtime = get_wall_time() - oldtime;

        out << oldtime << '\t' << oldcputime  << '\t'
            << mpscSize << '\t' << numPaths[0][0] << '\t'
            << mpscLength[0][0] << '\t';
        for (int j = 0; j<mpscSize; ++j){
            const int s = lengthMAXMPSC[j];
            out << "[" << smmStart(s) << ',' << smmEnd(s) << "]"
                << (j + 1 < mpscSize ? ',' : '\n');
        }
    }
}

// Command-line entry point. Options (only the character after the leading
// one is examined, as before):
//   -r <file> reference panel VCF (default panel.vcf)
//   -q <file> query VCF (default query.vcf)
//   -o <file> output statistics file (default output.txt)
//   -h        print help
// Exit status: 0 on success or explicit -h; 1 on a bad command line or when
// the output file cannot be opened.
int main(int argc, char *argv[]){
    std::string in = "panel.vcf", qin = "query.vcf", out = "output.txt";
    bool help = false;
    bool argError = false; // help is being shown because the command line was invalid
    for (int i = 1; !help && i<argc; ++i){
        const char *arg = argv[i];
        // An empty argument has no second character; reading arg[1] would run past its terminator.
        const char opt = (arg[0] != '\0') ? arg[1] : '\0';
        std::string *target = nullptr; // option value destination, if the option takes one
        switch(opt){
            case 'r':
            case 'R':
                target = &in;
                break;
            case 'q':
            case 'Q':
                target = &qin;
                break;
            case 'o':
            case 'O':
                target = &out;
                break;
            case 'h':
            case 'H':
                help = true;
                break;
            default:
                std::cout << arg << " is an invalid argument\n\n";
                help = argError = true;
                break;
        }
        if (target != nullptr){
            // argv[argc] is a null pointer, so the value must be checked to exist.
            if (i + 1 < argc)
                *target = argv[++i];
            else{
                std::cout << arg << " requires a file name\n\n";
                help = argError = true;
            }
        }
    }
    if (help){
        std::cout << "This program takes VCF files as input. The input VCF files must be in the format specified in aboutInput.txt. It takes as input two VCF files, one VCF file stores a reference panel. The other stores a set of haplotypes to thread through the reference panel. These VCF files must have the same number of sites and the i-th site in each VCF should refer to the same site in the genome for all i.\n\n"
            << "Given these VCF files, the program outputs for each query haplotype, the size of the Minimal Positional Substring Cover of the query haplotype by the reference panel, the size of the set maximal match only MPSC solution space, the length maximal MPSC length, and a Length Maximal MPSC of the query by the reference panel. It also outputs the time taken to obtain and output the above information (given a PBWT of the reference panel). The algorithms implemented here are presented in DOI:.\n\n"
            << "Compile with std=c++11 or higher. The following command may be used:\n\n\nSample Usage:\n\n"
            << "After compilation, you can run the program on the sample vcfs using the following command:\n\noptions:\n"
            << "-r, input VCF file for the reference panel, default is panel.vcf\n"
            << "-q, input VCF file for the panel of query haplotypes to search against the reference panel, default is query.vcf\n"
            << "-o, output file for statistics, default is output.txt\n"
            << "-h, help. output this screen\n";
        return argError ? 1 : 0;
    }

    std::ofstream outS(out);
    if (!outS.is_open()){
        std::cout << out << " failed to open!" << std::endl;
        return 1;
    }

    int M = 0, N = 0, qM = 0;
    bool **panel = nullptr;
    int **prefix, **divergence, **u, **v;

    double oldtime = 0, oldcputime = 0;

    //read VCF files
    oldtime = get_wall_time();
    oldcputime = get_cpu_time();
    ReadVCF(in, qin, M, qM, N, panel);
    std::cout << "M: " << M-qM << "\nN: " << N << "\nqM: " << qM << std::endl;
    oldcputime = get_cpu_time() - oldcputime;
    oldtime = get_wall_time() - oldtime;
    std::cout << "Reading panel took " << oldtime << " wall seconds and " << oldcputime  << " cpu seconds." << std::endl;

    oldtime = get_wall_time();
    oldcputime = get_cpu_time();
    // prefix/divergence need N+1 columns, u/v need N; each is one contiguous buffer
    // addressed with 64-bit offsets so large M * N does not overflow int.
    prefix = new int*[M];
    divergence = new int*[M];
    u = new int*[M];
    v = new int*[M];
    int *temp1 = new int[static_cast<long long>(M)*(N+1)];
    int *temp2 = new int[static_cast<long long>(M)*(N+1)];
    int *temp3 = new int[static_cast<long long>(M)*(N)];
    int *temp4 = new int[static_cast<long long>(M)*(N)];
    for (long long i = 0; i<M; i++){
        prefix[i] = &(temp1[i*(N+1)]);
        divergence[i] = &(temp2[i*(N+1)]);
        u[i] = &(temp3[i*(N)]);
        v[i] = &(temp4[i*(N)]);
    }
    PBWT(panel, prefix, divergence, u, v, M, N);
    oldcputime = get_cpu_time() - oldcputime;
    oldtime = get_wall_time() - oldtime;
    std::cout << "Building PBWT took " << oldtime << " wall seconds and " << oldcputime  << " cpu seconds." << std::endl;

    oldtime = get_wall_time();
    oldcputime = get_cpu_time();
    benchmark(panel, prefix, divergence, u, v, M, qM, N, outS);
    oldcputime = get_cpu_time() - oldcputime;
    oldtime = get_wall_time() - oldtime;
    std::cout << "MPSC Benchmark took " << oldtime << " wall seconds and " << oldcputime  << " cpu seconds." << std::endl;

    outS.close();
    return 0;
}
