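/*
 * Planned benchmark:
 * 1. Build the PBWT of the full panel except 1000 haplotypes ((M-1000) x N).
 * 2. For size in {1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, N}:
 * 3.   run length-maximal MPSC benchmark 1 on the (M-1000) x size panel:
 *        - find fragile regions
 *        - build the graph
 *        - find the length-maximal MPSC
 *        - count edges, nodes, and nodes per region
 *
 * NOTE: main() below builds the PBWT over all M haplotypes and benchmarks
 * every haplotype against it; the held-out and varying-size steps above are
 * not implemented in the code below.
 */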

#include<iostream>
#include<iomanip>
#include<fstream>
#include<sstream>
#include<string>
#include<vector>
#include<algorithm>
#include<cstdlib>
#include<memory>
#include<new>

// Panel dimensions: M haplotypes (rows of `panel`, two per VCF sample) by
// N variant sites (columns, one per VCF data line). Typed constants instead
// of macros, so later headers cannot have identifiers named M/N rewritten.
constexpr int M = 860022;
constexpr int N = 9793;

//-----------------------------------------------------------------------------
//source for time: https://stackoverflow.com/questions/17432502/how-can-i-measure-cpu-time-and-wall-clock-time-on-both-linux-windows
#include<time.h>
#include<sys/time.h>
// Returns the current wall-clock time in seconds (microsecond resolution),
// as reported by gettimeofday(). If the call fails an error is printed and
// 0.0 is returned: the struct is zero-initialised so a failure never yields
// an indeterminate value.
double get_wall_time(){
    struct timeval time{};
    if (gettimeofday(&time, nullptr) != 0){
        std::cerr<< "ERROR reading real time\n";
    }
    return static_cast<double>(time.tv_sec) + static_cast<double>(time.tv_usec) * .000001;
}
// Processor time used by the program so far, in seconds, based on clock().
double get_cpu_time(){
    const clock_t ticks = clock();
    return static_cast<double>(ticks) / CLOCKS_PER_SEC;
}
//-----------------------------------------------------------------------------

// Pipeline stages, defined below: load the panel, build its PBWT, benchmark.
void ReadVCF(const char* inFile, bool panel[][N]);
void PBWT(bool panel[][N], int prefix[][N+1], int divergence[][N+1],
        int u[][N], int v[][N], int num);
void benchmark(bool panel[][N], int prefix[][N+1], int divergence[][N+1],
        int u[][N], int v[][N]);

// Reads the genotypes of a phased VCF file into `panel`.
//
// Lines are skipped up to and including the first one starting with "#C"
// (the "#CHROM" header). Each of the next N lines is site j. After the 9
// fixed VCF columns, every sample field is parsed as
// <int><separator char><int> (e.g. "0|1") and stored in panel[2*s][j] and
// panel[2*s+1][j] for sample s < M/2; any nonzero allele is stored as true.
// NOTE(review): assumes GT-only, non-missing sample fields. A '.' or extra
// ':'-separated FORMAT subfields make extraction fail, which is reported.
// Exits the process if the file cannot be opened, has no header line, has
// fewer than N site lines, or a site line cannot be parsed.
void ReadVCF(const char* inFile, bool panel[][N]){
    using namespace std;
    ifstream in(inFile);
    if (!in.is_open()){
        std::cout << inFile << " failed to open!" << std::endl;
        exit(1);
    }

    // Find the header. Checking the length first avoids indexing past the end
    // of empty/short lines; stopping at EOF avoids looping forever when the
    // header is missing.
    string line;
    bool foundHeader = false;
    while (getline(in, line)){
        if (line.size() >= 2 && line[0] == '#' && line[1] == 'C'){
            foundHeader = true;
            break;
        }
    }
    if (!foundHeader){
        std::cerr << inFile << ": no #CHROM header line found\n";
        exit(1);
    }

    stringstream linestr;
    int x = 0;
    char y = 0;
    for(int j = 0; j<N; ++j){
        if (!getline(in, line)){
            std::cerr << inFile << ": expected " << N << " sites but the file ended after "
                << j << '\n';
            exit(1);
        }
        linestr.str(line);
        linestr.clear();
        // skip the 9 fixed columns (CHROM..FORMAT)
        for (int i = 0; i<9; ++i)
            linestr >> line;
        for (int i = 0; i<M/2; ++i){
            linestr >> x >> y;        // first allele and the phase separator
            panel[i*2][j] = (x != 0);
            linestr >> x;             // second allele
            panel[i*2 + 1][j] = (x != 0);
        }
        // a failed extraction would otherwise silently fill the rest with 0
        if (linestr.fail()){
            std::cerr << inFile << ": malformed genotype data at site " << j << '\n';
            exit(1);
        }
    }
    in.close();
}

// Builds the positional Burrows-Wheeler transform of the first `num`
// haplotypes (rows) of `panel` over all N sites, using the prefix/divergence
// update of Durbin's Algorithm 2.
//
// Outputs, for every column k:
//   prefix[i][k]     haplotype at position i of the positional prefix array
//                    (column 0 is the identity order);
//   divergence[i][k] start site of the longest match, ending at site k-1,
//                    between prefix[i][k] and the haplotype above it
//                    (k when there is none);
//   u[i][k]          number of positions < i whose haplotype has 0 at site k;
//   v[i][k]          (total 0s at site k) + number of positions < i whose
//                    haplotype has 1 at site k.
// u/v map position i in column k to the position of the same haplotype in
// column k+1. Every array must have at least `num` rows.
void PBWT(bool panel[][N], int prefix[][N+1], int divergence[][N+1],
        int u[][N], int v[][N], int num){
    // Column 0 uses the identity order. To start from another order (e.g. one
    // read from order.txt), assign it to prefix[i][0] here.
    for (int i = 0; i<num; ++i){
        prefix[i][0] = i;
        divergence[i][0] = 0;
    }

    // Per-column scratch: a/d collect haplotypes (and divergences) with 0 at
    // site k, b/e those with 1. Hoisted so every column reuses the capacity.
    std::vector<int> a, b, d, e;
    for (int k = 0; k<N; ++k){
        a.clear(); b.clear(); d.clear(); e.clear();
        int u2 = 0, v2 = 0, p = k+1, q = k+1;
        for (int i = 0; i<num; ++i){
            u[i][k] = u2;
            v[i][k] = v2;
            if (divergence[i][k] > p) { p = divergence[i][k];}
            if (divergence[i][k] > q) { q = divergence[i][k];}
            if (!panel[prefix[i][k]][k]){
                a.push_back(prefix[i][k]);
                d.push_back(p);
                ++u2;
                p = 0;
            }
            else{
                b.push_back(prefix[i][k]);
                e.push_back(q);
                ++v2;
                q = 0;
            }
        }
        // stable partition: all 0-carriers first, then all 1-carriers
        const int zeros = static_cast<int>(a.size());
        for (int i = 0; i<num; ++i){
            v[i][k] += zeros;
            if (i < zeros){
                prefix[i][k+1] = a[i];
                divergence[i][k+1] = d[i];
            }
            else{
                prefix[i][k+1] = b[i-zeros];
                divergence[i][k+1] = e[i-zeros];
            }
        }
    }
}

// Runs the minimal positional substring cover (MPSC) benchmark for every
// haplotype of the panel against the PBWT built by PBWT().
//
// For each haplotype Oid:
//   * locate Oid in every column of the positional prefix array via u/v;
//   * for every site, find the start of the longest match ending there
//     (mind) and record the set maximal matches (SMMs);
//   * build the leftmost and rightmost MPSC and the required regions;
//   * build a layered graph whose layer i holds the SMMs that fully contain
//     required region i. Compute the length-maximal MPSC (longest path,
//     length = sum of SMM lengths) and the number of SMM-only MPSCs.
//
// Output files, one line per haplotype that has an MPSC:
//   time_out.txt      Oid, wall seconds, CPU seconds
//   solSpace_out.txt  number of SMM-only MPSCs
//   length_out.txt    length of the length-maximal MPSC
//   mpsc_out.txt      [start,end] of each SMM in the length-maximal MPSC
//   reqRegOut.txt     required regions
//   smmOut.txt        MPSC size, then the number of SMMs per required region
// Haplotypes without an MPSC are reported on stdout and skipped; haplotypes
// with a required region not contained by any SMM are reported on stderr and
// skipped. `prefix` is accepted for symmetry with PBWT() but not read.
void benchmark(bool panel[][N], int prefix[][N+1], int divergence[][N+1], 
        int u[][N], int v[][N]){
    std::ofstream timeOut("time_out.txt"), solSpaOut("solSpace_out.txt"), lengthOut("length_out.txt"), mpscOut("mpsc_out.txt"), reqRegOut("reqRegOut.txt"), smmOut("smmOut.txt");
    // Order in which haplotypes are benchmarked (identity by default). Kept on
    // the heap: M ints is several megabytes, too much for a stack array.
    std::vector<int> order(M);
    /*{
        //get order
        std::ifstream orderIn("order.txt");
        if (!orderIn.is_open()){
            std::cout << "order.txt failed to open!" << std::endl;
            exit(1);
        }
        for (int i = 0; i<M; ++i)
            orderIn >> order[i];
        orderIn.close();
    }*/
    for (int i = 0; i<M; ++i)
        order[i] = i;

    for (int i = 0; i<M; ++i){
        double oldtime = get_wall_time();
        double oldcputime = get_cpu_time();
        //benchmark for order[i]
        const int Oid = order[i];

        //get PPA positions: t[j] is the position of Oid in column j
        std::vector<int> t(N+1);
        t[0] = Oid;
        for (int j = 0; j<N; ++j)
            t[j+1] = (panel[Oid][j]) ? v[t[j]][j] : u[t[j]][j];

        //get divergence values and set maximal matches
        //mind[j] is the start of the longest match ending at site j-1, j if none
        std::vector<int> mind(N+2);
        mind[N+1] = N;
        std::vector<int> setMaxStart, setMaxEnd;
        for (int j = N; j>=0; --j){
            //divergence with the neighbour below; the last position has none
            const int bd = (t[j] != M-1) ? divergence[t[j]+1][j] : j;
            mind[j] = std::min(divergence[t[j]][j], bd);
            if (mind[j] < mind[j+1]){
                //record set maximal match found
                setMaxStart.push_back(mind[j]);
                setMaxEnd.push_back(j-1);
            }
        }

        //get leftmost mpsc
        //l[|C|-i-1] is the starting point of the i-th positional substring of the leftmost mpsc
        std::vector<int> l;
        int k = mind[N], oldk = N;
        while (k != oldk){
            l.push_back(k);
            oldk = k;
            k = mind[k];
        }
        if (k != 0) {
            //skip this haplotype because no mpsc of it by M exists
            std::cout << Oid << ":No MPSC\n";
            continue;
        }

        //get rightmost MPSC
        //b is the opposite of mind: b[i] is the end of the longest match starting at site i, i-1 if none
        std::vector<int> b(N);
        //d is the first site calculated so far, k is the current site
        int d = N;
        k = N;
        while (d != 0){
            while (mind[k] < d)
                b[--d] = k-1;
            --k;
        }

        //r[i] is the ending point of the i-th positional substring in the rightmost MPSC
        std::vector<int> r;
        //k = first site not covered so far
        k = 0;
        while (k != N){
            r.push_back(b[k]);
            k = b[k]+1;
        }

        const int mpscSize = static_cast<int>(r.size()),
            numSMMs = static_cast<int>(setMaxStart.size());
        //setMaxStart/setMaxEnd are stored in reverse: SMM number s lives at index numSMMs-s-1
        auto smmIndex = [numSMMs](int s){ return numSMMs - s - 1; };

        //SMMContains[i] lists, sorted by first site, the SMMs that fully contain the i-th required region
        std::vector<std::vector<int>> SMMContains(mpscSize);
        if (numSMMs == 1){
            SMMContains[0].push_back(0);
        }
        else{
            //the 0-th and |C|-1 th required regions are only fully contained by the first and last set maximal
            //matches because they are the only SMMs that contain sites 0 and N-1 respectively
            SMMContains[0].push_back(0);
            SMMContains[mpscSize-1].push_back(numSMMs-1);
            //middle regions 1..mpscSize-2 exist only when mpscSize > 2; with mpscSize == 2 the
            //index mpscSize-currentRegion-2 below would be -1
            if (mpscSize > 2){
                //region i spans [r[i-1]+1, l[mpscSize-i-2]-1]
                int currentRegion = 1;
                int currRegStart = r[currentRegion-1]+1;
                int currRegEnd = l[mpscSize-currentRegion-2]-1;
                for (int j = numSMMs-2; j>0; --j){
                    //j indexes SMM number numSMMs-j-1, visited in increasing order of first site
                    while (currRegStart < setMaxStart[j]){
                        ++currentRegion;
                        if (currentRegion > mpscSize-2)
                            break;
                        currRegStart = r[currentRegion-1]+1;
                        currRegEnd = l[mpscSize-currentRegion-2]-1;
                    }
                    if (currentRegion > mpscSize-2)
                        break;
                    if (setMaxEnd[j] >= currRegEnd)
                        SMMContains[currentRegion].push_back(numSMMs-j-1);
                }
            }
        }

        //every layer of the graph below must be non-empty, otherwise [0] would be out of range
        if (std::any_of(SMMContains.begin(), SMMContains.end(),
                    [](const std::vector<int>& smms){ return smms.empty(); })){
            std::cerr << Oid << ":Required region not contained by any SMM\n";
            continue;
        }

        //mpscLength[i][n] is the length of the longest path from node n of region i to the sink
        std::vector<std::vector<int>> mpscLength(mpscSize);
        for (int j = 0; j<mpscSize; ++j)
            mpscLength[j].resize(SMMContains[j].size());
        //length of longest path of last SMM is the length of the last SMM
        mpscLength[mpscSize-1][0] = setMaxEnd[0] - setMaxStart[0] + 1;
        for (int j = mpscSize-2; j>=0; --j){
            //nodes of region j+1 reachable from a node of region j (starting no later than one site
            //after its end) form a prefix; prevMax is the max over nodes 0..prevMaxIndex of region j+1
            const int nextSize = static_cast<int>(SMMContains[j+1].size());
            int prevMax = -1, prevMaxIndex = -1;
            for (size_t n = 0; n<SMMContains[j].size(); ++n){
                const int cur = smmIndex(SMMContains[j][n]);
                while (prevMaxIndex+1 < nextSize &&
                        setMaxEnd[cur] >= setMaxStart[smmIndex(SMMContains[j+1][prevMaxIndex+1])]-1){
                    ++prevMaxIndex;
                    prevMax = std::max(prevMax, mpscLength[j+1][prevMaxIndex]);
                }
                mpscLength[j][n] = prevMax + setMaxEnd[cur] - setMaxStart[cur] + 1;
            }
        }
        //mpscLength[0][0] now contains the length of the length maximal MPSC

        //backtracking step to find the length maximal mpsc: in each region pick, among the nodes
        //reachable from the node chosen in the previous region, the one with the longest path
        std::vector<int> lengthMAXMPSC(mpscSize);
        int prevEnd = 0;
        for (int j = 0; j < mpscSize; ++j){
            size_t best = 0;
            for (size_t n = 1; n<SMMContains[j].size(); ++n){
                //nodes are sorted by first site: once one is unreachable, so are all later ones
                if (j > 0 && prevEnd < setMaxStart[smmIndex(SMMContains[j][n])]-1)
                    break;
                if (mpscLength[j][n] > mpscLength[j][best])
                    best = n;
            }
            lengthMAXMPSC[j] = SMMContains[j][best];
            prevEnd = setMaxEnd[smmIndex(lengthMAXMPSC[j])];
        }

        //calculate solution space size (set maximal match only MPSCs)
        //numPaths[i][n] is the number of paths from node n of region i to the sink
        std::vector<std::vector<long double>> numPaths(mpscSize);
        for (int j = 0; j<mpscSize; ++j)
            numPaths[j].resize(SMMContains[j].size());
        //numPaths from last SMM to itself is 1
        numPaths[mpscSize-1][0] = 1;
        for (int j = mpscSize-2; j>=0; --j){
            //prevNumPaths accumulates numPaths over the reachable prefix 0..prevMaxIndex of region j+1
            const int nextSize = static_cast<int>(SMMContains[j+1].size());
            long double prevNumPaths = 0;
            int prevMaxIndex = -1;
            for (size_t n = 0; n<SMMContains[j].size(); ++n){
                const int cur = smmIndex(SMMContains[j][n]);
                while (prevMaxIndex+1 < nextSize &&
                        setMaxEnd[cur] >= setMaxStart[smmIndex(SMMContains[j+1][prevMaxIndex+1])]-1)
                    prevNumPaths += numPaths[j+1][++prevMaxIndex];
                numPaths[j][n] = prevNumPaths;
            }
        }
        //numPaths[0][0] now contains the size of the set maximal match only solution space

        oldcputime = get_cpu_time() - oldcputime;
        oldtime = get_wall_time() - oldtime;

        timeOut << Oid << '\t' << oldtime << '\t' << oldcputime  << '\n';
        solSpaOut << numPaths[0][0] << '\n';
        lengthOut << mpscLength[0][0] << '\n';
        for(int j = 0; j<mpscSize-1; ++j)
            mpscOut << "[" << setMaxStart[smmIndex(lengthMAXMPSC[j])] << ',' << setMaxEnd[smmIndex(lengthMAXMPSC[j])] << "],";
        mpscOut << "[" << setMaxStart[smmIndex(lengthMAXMPSC[mpscSize - 1])] 
            << ',' << setMaxEnd[smmIndex(lengthMAXMPSC[mpscSize - 1])] << "]\n";

        //Required Region output
        if (mpscSize <= 1)
            reqRegOut << "[0,"<<N-1<<"], MPSC Size = " << mpscSize << '\n';
        else{
            reqRegOut << "[0," << l[mpscSize-2]-1 << "]";
            for (int j = 1; j<mpscSize-1; ++j)
                reqRegOut << ",[" << r[j-1]+1 << ',' << l[mpscSize-j-2]-1 << "]";
            reqRegOut << ",[" << r[mpscSize-2]+1 << "," << N-1 << "]\n";
        }

        smmOut << mpscSize;
        for (int j = 0; j<mpscSize; ++j)
            smmOut << '\t' << SMMContains[j].size();
        smmOut << '\n';
    }
    timeOut.close();
    solSpaOut.close();
    lengthOut.close();
    mpscOut.close();
    reqRegOut.close();
    smmOut.close();
}

int main(){
    bool panel[M][N];
    int prefix[M][N+1], divergence[M][N+1], u[M][N], v[M][N];
    double oldtime = 0, oldcputime = 0;

    oldtime = get_wall_time();
    oldcputime = get_cpu_time();
    ReadVCF("panelBritOnly.vcf.smooth.vcf", panel);
    oldcputime = get_cpu_time() - oldcputime;
    oldtime = get_wall_time() - oldtime;
    std::cout << "Reading panel took " << oldtime << " wall seconds and " << oldcputime  << " cpu seconds." << std::endl;

    oldtime = get_wall_time();
    oldcputime = get_cpu_time();
    PBWT(panel, prefix, divergence, u, v, M);
    oldcputime = get_cpu_time() - oldcputime;
    oldtime = get_wall_time() - oldtime;
    std::cout << "Building PBWT took " << oldtime << " wall seconds and " << oldcputime  << " cpu seconds." << std::endl;

    oldtime = get_wall_time();
    oldcputime = get_cpu_time();
    benchmark(panel, prefix, divergence, u, v);
    oldcputime = get_cpu_time() - oldcputime;
    oldtime = get_wall_time() - oldtime;
    std::cout << "Solution Space benchmark took " << oldtime << " wall seconds and " << oldcputime  << " cpu seconds." << std::endl;
    return 0;
}
