package main.java.pg;

import com.compomics.util.experiment.biology.PTMFactory;
import com.compomics.util.experiment.biology.Peptide;
import com.compomics.util.experiment.identification.matches.ModificationMatch;
import com.compomics.util.experiment.io.massspectrometry.MgfFileIterator;
import com.compomics.util.experiment.massspectrometry.*;
import main.java.PSMMatch.JPeptideSpectrumMatch;
import main.java.PSMMatch.JPeak;
import main.java.PSMMatch.JPeptide;
import main.java.PSMMatch.JSpectrum;
import uk.ac.ebi.jmzml.xml.io.MzMLUnmarshallerException;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.sql.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Reads MS/MS spectra (from an MGF file or from a SQLite spectrum database) and keeps the spectra
 * whose precursor mass is compatible with the mass of one or more target peptides.
 */
public class SpectraInput {

    /** Progress is printed to stdout every this many spectra read from an MGF file. */
    private static final int PROGRESS_INTERVAL = 10000;

    /**
     * Set to {@code true} once {@link #readMSMSfast} or {@link #readMSMSdb2} has finished
     * collecting candidate spectra for the target peptides.
     */
    public static boolean isTargetPeptide2spectrumScored = false;

    /**
     * Matched spectra keyed by spectrum title. For MGF input the title is the 1-based position of
     * the spectrum in the file; for SQLite input it is the {@code id} column of {@code msmsdb}.
     */
    public static HashMap<String,MSnSpectrum> spectraMap = new HashMap<>();


    /**
     * Manual test harness: reads {@code Ups1.mgf}, keeps the spectra within 10 ppm of an
     * unmodified test peptide and prints their titles.
     */
    public static void main(String[] args) throws IOException, MzMLUnmarshallerException {

        String mgf = "Ups1.mgf";
        String peptideSeq = "TPSAAYLWVGTGASEAEK";
        ArrayList<ModificationMatch> modificationMatches = new ArrayList<ModificationMatch>();
        Peptide peptide = new Peptide(peptideSeq,modificationMatches);
        ArrayList<MSnSpectrum> msnSpectrums = readMSMS(mgf,peptide,10,true,10);
        System.out.println(msnSpectrums.size());

        for(MSnSpectrum msnSpectrum: msnSpectrums){
            System.out.println(msnSpectrum.getSpectrumTitle());

            // Convert to the internal spectrum representation, sorted by m/z.
            JSpectrum jSpectrum = new JSpectrum();
            for (Peak p : msnSpectrum.getPeakList()) {
                jSpectrum.addRawPeak(new JPeak(p.getMz(), p.getIntensity()));
            }
            jSpectrum.resetPeaks();
            jSpectrum.sortPeaksByMZ();

            // NOTE(review): this PSM is built but never scored or stored; presumably a placeholder.
            // readMSMS only returns spectra with at least one possible charge, so get(0) is safe.
            JPeptideSpectrumMatch psm = new JPeptideSpectrumMatch();
            psm.setCalculatedMassToCharge(msnSpectrum.getPrecursor().getMz());
            psm.setCharge(msnSpectrum.getPrecursor().getPossibleCharges().get(0).value);
            psm.setExperimentalMassToCharge(msnSpectrum.getPrecursor().getMz());
            psm.setRank(1);
            psm.setSpectrumID(msnSpectrum.getSpectrumTitle());
            psm.setPepSeq(peptide.getSequence());
        }

    }


    /**
     * Reads every spectrum of an MGF file and keeps those whose precursor neutral mass lies within
     * {@code tol} of the mass of {@code peptide}.
     *
     * <p>The neutral mass is computed from the first possible charge of the precursor. Spectra
     * with at most {@code minPeaks} peaks are skipped. Spectra whose precursor carries no charge
     * state are also skipped, because no neutral mass can be computed for them.
     *
     * @param spectraFile path to the MGF file
     * @param peptide     target peptide
     * @param tol         mass tolerance
     * @param isPpm       {@code true} if {@code tol} is in ppm, {@code false} if it is in Da
     * @param minPeaks    spectra with this many peaks or fewer are discarded
     * @return the spectra passing both filters, in file order
     * @throws IOException if the MGF file cannot be read
     */
    public static ArrayList<MSnSpectrum> readMSMS(String spectraFile, Peptide peptide, double tol, boolean isPpm, int minPeaks) throws IOException {

        ArrayList<MSnSpectrum> msnSpectrums = new ArrayList<MSnSpectrum>();
        MgfFileIterator mgfFileIterator = new MgfFileIterator(new File(spectraFile));

        // Same acceptance window as the SQLite path; supports both ppm and Da tolerances.
        double[] massRange = getRangeOfMass(peptide.getMass(), tol, isPpm);

        int i = 0;
        while (mgfFileIterator.hasNext()) {
            i++;
            MSnSpectrum spectrum = mgfFileIterator.next();
            boolean usable = spectrum.getNPeaks() > minPeaks
                    && !spectrum.getPrecursor().getPossibleCharges().isEmpty();
            if (usable) {
                int charge = spectrum.getPrecursor().getPossibleCharges().get(0).value;
                double mass = spectrum.getPrecursor().getMass(charge);
                if (mass >= massRange[0] && mass <= massRange[1]) {
                    msnSpectrums.add(spectrum);
                }
            }

            if (i % PROGRESS_INTERVAL == 0) {
                System.out.println("Finished:"+i+","+msnSpectrums.size());
            }
        }

        return msnSpectrums;
    }

    /**
     * Finds and scores the MGF spectra that match any of several target peptides or their PTM
     * isoforms, then stores every matched spectrum in {@link #spectraMap}.
     *
     * <p>The method makes two passes over the file:
     * <ol>
     *   <li>Each spectrum is scored in a thread pool against a 0.1 Da mass index of the peptides.
     *       Match results are recorded on the {@code JPeptide} objects.</li>
     *   <li>The file is re-read and the spectra that matched at least one peptide are kept.</li>
     * </ol>
     * Spectrum titles are replaced with their 1-based position in the file, so both passes agree
     * on the identifiers.
     *
     * @param spectraFile   path to the MGF file
     * @param peptideInputs target peptides
     * @throws IOException          if the MGF file cannot be read
     * @throws InterruptedException if interrupted while waiting for the scoring tasks
     */
    public static void readMSMSfast(String spectraFile, ArrayList<PeptideInput> peptideInputs) throws IOException, InterruptedException {

        ConcurrentHashMap<Integer,ArrayList<JPeptide>> pIndex = buildPeptideMassIndex(peptideInputs);
        System.out.println("Build target peptide index done.");

        File mgfFile = new File(spectraFile);
        scoreMgfSpectra(mgfFile, pIndex);

        // Some spectra may not have charge state information. Save the charge state used for
        // scoring these spectra so their precursor can be rebuilt in the second pass.
        HashMap<String,Integer> spectrumID2charge = new HashMap<>();
        HashSet<String> matchSpectra = new HashSet<>(128);
        int npsms = 0;
        for (PeptideInput peptideInput : peptideInputs) {
            npsms += collectMatches(peptideInput.jPeptide, matchSpectra, spectrumID2charge);
            for (JPeptide isoform : peptideInput.getPtmIsoforms()) {
                npsms += collectMatches(isoform, matchSpectra, spectrumID2charge);
            }
        }

        System.out.println("Matched spectra: "+matchSpectra.size());
        System.out.println("Matched PSMs: "+npsms);
        if (spectrumID2charge.size() >= 1) {
            System.out.println("Spectra without charge information: " + spectrumID2charge.size());
        }

        saveMatchedMgfSpectra(mgfFile, matchSpectra, spectrumID2charge);

        System.out.println("Saved spectra: "+spectraMap.size());
        System.out.println("Get matched spectra done!");

        isTargetPeptide2spectrumScored = true;
    }

    /**
     * Indexes the peptides and their PTM isoforms by mass.
     * The key is {@code round(10 * mass)}, i.e. 0.1 Da bins.
     */
    private static ConcurrentHashMap<Integer,ArrayList<JPeptide>> buildPeptideMassIndex(ArrayList<PeptideInput> peptideInputs) {
        ConcurrentHashMap<Integer,ArrayList<JPeptide>> pIndex = new ConcurrentHashMap<>();
        for (PeptideInput peptideInput : peptideInputs) {
            addToMassIndex(pIndex, peptideInput.jPeptide);
            for (JPeptide isoform : peptideInput.getPtmIsoforms()) {
                addToMassIndex(pIndex, isoform);
            }
        }
        return pIndex;
    }

    /** Appends {@code jPeptide} to the bin of its mass in {@code pIndex}, creating the bin if needed. */
    private static void addToMassIndex(ConcurrentHashMap<Integer,ArrayList<JPeptide>> pIndex, JPeptide jPeptide) {
        int rmass = (int) Math.round(10.0*jPeptide.peptide.getMass());
        ArrayList<JPeptide> bin = pIndex.get(rmass);
        if (bin == null) {
            bin = new ArrayList<>();
            pIndex.put(rmass, bin);
        }
        bin.add(jPeptide);
    }

    /**
     * First pass over the MGF file. Each spectrum with more than {@code CParameter.minPeaks} peaks
     * is submitted for scoring against {@code pIndex}. Returns once all scoring tasks are done.
     */
    private static void scoreMgfSpectra(File mgfFile, ConcurrentHashMap<Integer,ArrayList<JPeptide>> pIndex) throws IOException, InterruptedException {
        ExecutorService fixedThreadPoolScore = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            MgfFileIterator mgfFileIterator = new MgfFileIterator(mgfFile);
            int i = 0;
            while (mgfFileIterator.hasNext()) {
                i++;
                MSnSpectrum spectrum = mgfFileIterator.next();
                spectrum.setSpectrumTitle(String.valueOf(i));
                // Spectra with at most minPeaks peaks are not scored.
                if (spectrum.getNPeaks() > CParameter.minPeaks) {
                    // 2018-04-19: the spectrum may lack charge information. The worker is expected
                    // to record the charge it used in JPeptide.spectrumID2charge.
                    fixedThreadPoolScore.execute(new FindCandidateSpectraAndScoreWorker(pIndex,spectrum));
                }

                if (i % PROGRESS_INTERVAL == 0) {
                    System.out.println("Finished:"+i);
                }
            }
        } finally {
            // Always shut the pool down; otherwise its non-daemon threads would outlive a failure.
            fixedThreadPoolScore.shutdown();
        }
        fixedThreadPoolScore.awaitTermination(Long.MAX_VALUE, TimeUnit.HOURS);
    }

    /**
     * Adds the spectra matched by {@code jPeptide} to {@code matchSpectra}.
     * Also copies its recorded per-spectrum charges into {@code spectrumID2charge}.
     *
     * @return the number of PSMs of this peptide (0 if it matched nothing)
     */
    private static int collectMatches(JPeptide jPeptide, HashSet<String> matchSpectra, HashMap<String,Integer> spectrumID2charge) {
        if (jPeptide.spectraIndexs.isEmpty()) {
            return 0;
        }
        matchSpectra.addAll(jPeptide.spectraIndexs);
        for (String spectrumID : jPeptide.spectrumID2charge.keySet()) {
            spectrumID2charge.put(spectrumID, jPeptide.spectrumID2charge.get(spectrumID));
        }
        return jPeptide.spectraIndexs.size();
    }

    /**
     * Second pass over the MGF file: stores the spectra whose position-based title is in
     * {@code matchSpectra} into {@link #spectraMap}. A stored spectrum whose precursor has no
     * charge gets its precursor rebuilt with the charge that was used for scoring it.
     */
    private static void saveMatchedMgfSpectra(File mgfFile, HashSet<String> matchSpectra, HashMap<String,Integer> spectrumID2charge) throws IOException {
        MgfFileIterator mgfFileIterator = new MgfFileIterator(mgfFile);
        int i = 0;
        while (mgfFileIterator.hasNext()) {
            i++;
            MSnSpectrum spectrum = mgfFileIterator.next();
            String title = String.valueOf(i);
            spectrum.setSpectrumTitle(title);

            if (matchSpectra.contains(title)) {
                if (spectrum.getPrecursor().getPossibleCharges().isEmpty()) {
                    Integer charge = spectrumID2charge.get(title);
                    if (charge == null) {
                        throw new IllegalStateException("Spectrum " + title
                                + " has no charge information and no scoring charge was recorded for it");
                    }
                    ArrayList<Charge> possibleCharges = new ArrayList<>();
                    possibleCharges.add(new Charge(Charge.PLUS, charge));
                    spectrum.setPrecursor(new Precursor(0.0, spectrum.getPrecursor().getMz(), possibleCharges));
                }
                spectraMap.put(title, spectrum);
            }

            if (i % PROGRESS_INTERVAL == 0) {
                System.out.println("Finished:"+i);
            }
        }
    }


    /**
     * Scores the spectra of a SQLite spectrum database against the target peptides and their PTM
     * isoforms, and stores the matched spectra in {@link #spectraMap}.
     *
     * <p>For each peptide or isoform, the rows of {@code msmsdb} whose {@code mass} lies in the
     * tolerance window ({@code CParameter.tol}, {@code CParameter.tolu}) are loaded. Rows with more
     * than {@code CParameter.minPeaks} peaks are scored on a single-threaded executor.
     *
     * @param sqldb         path to the SQLite database file
     * @param peptideInputs target peptides
     * @throws SQLException         on database errors
     * @throws IOException          if a stored spectrum cannot be deserialized
     * @throws ClassNotFoundException if a stored spectrum's class is not on the classpath
     * @throws InterruptedException if interrupted while waiting for the scoring tasks
     */
    public static void readMSMSdb2(String sqldb, ArrayList<PeptideInput> peptideInputs) throws IOException, SQLException, ClassNotFoundException, InterruptedException {

        // Number of (peptide, candidate spectrum) pairs submitted for scoring.
        int candidatePsms = 0;
        ConcurrentHashMap<String,MSnSpectrum> matchedSpectra = new ConcurrentHashMap<>();

        try (Connection connection = DriverManager.getConnection("jdbc:sqlite:" + sqldb);
             PreparedStatement pstmt = connection.prepareStatement("select id,spectrum from msmsdb where mass >= ? and mass <= ?")) {

            // don't use multiple threads here. Multiple threads will cause problem here.
            ExecutorService fixedThreadPoolScore = Executors.newFixedThreadPool(1);
            try {
                for (PeptideInput peptideInput : peptideInputs) {
                    candidatePsms += submitDbCandidates(pstmt, peptideInput.jPeptide, fixedThreadPoolScore, matchedSpectra);
                    for (JPeptide isoform : peptideInput.getPtmIsoforms()) {
                        candidatePsms += submitDbCandidates(pstmt, isoform, fixedThreadPoolScore, matchedSpectra);
                    }
                }
            } finally {
                fixedThreadPoolScore.shutdown();
            }
            fixedThreadPoolScore.awaitTermination(Long.MAX_VALUE, TimeUnit.HOURS);
        }

        spectraMap.putAll(matchedSpectra);

        System.out.println("Matched spectra: "+spectraMap.size());
        System.out.println("Matched PSMs: "+candidatePsms);

        isTargetPeptide2spectrumScored = true;
    }

    /**
     * Queries the spectra within the mass tolerance of {@code jPeptide} and submits one scoring task
     * per spectrum to {@code pool}.
     *
     * @return the number of candidate spectra submitted
     */
    private static int submitDbCandidates(PreparedStatement pstmt, JPeptide jPeptide, ExecutorService pool,
                                          ConcurrentHashMap<String,MSnSpectrum> matchedSpectra)
            throws SQLException, IOException, ClassNotFoundException {
        double[] massRange = getRangeOfMass(jPeptide.peptide.getMass(), CParameter.tol, CParameter.tolu);
        pstmt.setDouble(1, massRange[0]);
        pstmt.setDouble(2, massRange[1]);

        ArrayList<MSnSpectrum> candidates;
        try (ResultSet rs = pstmt.executeQuery()) {
            candidates = getSpectraFromSQL(rs, CParameter.minPeaks);
        }
        for (MSnSpectrum candidate : candidates) {
            pool.execute(new Peptide2SpectrumScoreWorker(jPeptide, candidate, matchedSpectra));
        }
        return candidates.size();
    }


    /**
     * Calculates the closed mass window {@code [pepmass - d, pepmass + d]} for a tolerance.
     *
     * @param pepmass peptide mass
     * @param tol     tolerance value
     * @param isPpm   {@code true} if {@code tol} is in ppm (then {@code d = tol * pepmass / 1e6}),
     *                {@code false} if it is an absolute mass ({@code d = tol})
     * @return a two-element array {@code {lower, upper}}
     */
    public static double[] getRangeOfMass(double pepmass, double tol, boolean isPpm){
        double delta = isPpm ? (tol * pepmass) / 1.0e6 : tol;
        return new double[] {pepmass - delta, pepmass + delta};
    }


    /**
     * Deserializes the spectra of a {@code (id, spectrum)} result set. Only the spectra with more
     * than {@code minPeaks} peaks are kept; their title is set to the row's {@code id}. Rows with
     * a NULL {@code spectrum} blob are skipped. The result set is consumed but not closed.
     *
     * <p>NOTE(review): this uses Java-native deserialization ({@link ObjectInputStream}), which is
     * only safe when the database comes from a trusted source.
     *
     * @param rs       result set with {@code id} and {@code spectrum} (serialized MSnSpectrum) columns
     * @param minPeaks spectra with this many peaks or fewer are discarded
     * @return the kept spectra, in result-set order
     */
    public static ArrayList<MSnSpectrum> getSpectraFromSQL(ResultSet rs, int minPeaks) throws SQLException, IOException, ClassNotFoundException {
        ArrayList<MSnSpectrum> spectrums = new ArrayList<>();

        while (rs.next()) {
            byte[] buf = rs.getBytes("spectrum");
            if (buf == null) {
                continue;
            }
            String id = rs.getString("id");
            MSnSpectrum p;
            try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(buf))) {
                p = (MSnSpectrum) objectInputStream.readObject();
            }
            if (p.getNPeaks() > minPeaks) {
                p.setSpectrumTitle(id);
                spectrums.add(p);
            }
        }
        return spectrums;
    }


}
