/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.LMS.RNASeq.singlecell;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.sql.PreparedStatement;
import java.util.HashSet;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import javax.json.Json;
import javax.json.JsonArray;
import javax.json.JsonObject;
import javax.json.JsonReader;
import org.rhwlab.db.MySql;

/**
 *
 * @author gevirl
 */
/**
 * Splits a pool of single-cell reads into one FASTQ file per well.
 * <p>
 * Three FASTQ inputs are read in lock-step: the experimental-barcode read, the well read
 * (a six-base molecular barcode followed by the well barcode) and the sequence read. A read
 * is written to a well's file when its experimental and well barcodes each match a known
 * barcode with at most {@code Mismatches} mismatches and that pair belongs to a configured
 * well; otherwise the concatenated read is written to the failures file. Read counts are then
 * stored in the {@code SingleCellWell} and {@code SingleCellPool} database tables.
 * <p>
 * Configuration keys (JSON): {@code PoolID}, {@code Mismatches} (default 1),
 * {@code BarcodeFile}, {@code WellBarcodeFile}, {@code SequenceFile}, {@code FailuresFile},
 * and {@code Wells}, an array of objects with {@code WellID}, {@code WellBarcode},
 * {@code ExpBarcode} and {@code Fastq}.
 *
 * @author gevirl
 */
public class Demultiplexer implements Runnable {

    /** Number of molecular-barcode bases at the start of each well read. */
    static private final int MOLECULAR_BARCODE_LENGTH = 6;

    /**
     * Creates a demultiplexer and immediately opens all input and output files.
     *
     * @param jsonObj the run configuration (keys are listed in the class documentation)
     * @throws Exception if the configuration is incomplete or a file cannot be opened
     */
    public Demultiplexer(JsonObject jsonObj) throws Exception {
        this.jsonObj = jsonObj;
        init();
    }

    /**
     * Opens the three input FASTQ files, one output stream per well and the failures file,
     * using the settings in {@link #jsonObj}.
     * <p>
     * The parent directory of every well's output FASTQ is created if missing, or emptied if
     * it already exists. Each directory is emptied at most once, so wells sharing an output
     * directory do not delete each other's freshly opened files. WARNING: anything else stored
     * directly in those directories is deleted as well.
     *
     * @throws IllegalArgumentException if two wells in the configuration use the same
     *     (experimental, well) barcode pair
     * @throws Exception if the configuration is incomplete or a file cannot be opened; files
     *     already opened are closed before the exception propagates
     */
    final public void init() throws Exception {
        try {
            poolID = jsonObj.getString("PoolID");
            missAllowed = jsonObj.getInt("Mismatches", 1);
            // open the input fastq files
            expFile = fastqReader(jsonObj.getString("BarcodeFile"));
            wellFile = fastqReader(jsonObj.getString("WellBarcodeFile"));
            seqFile = fastqReader(jsonObj.getString("SequenceFile"));

            // open an output file for each well
            Set<File> preparedDirs = new HashSet<>();
            Set<String> keysInConfig = new HashSet<>();
            JsonArray wells = jsonObj.getJsonArray("Wells");
            for (int i = 0; i < wells.size(); ++i) {
                JsonObject well = wells.getJsonObject(i);
                String wellID = well.getString("WellID");
                String wellBarcode = well.getString("WellBarcode").toUpperCase();
                String expBarcode = well.getString("ExpBarcode").toUpperCase();
                String key = expBarcode + wellBarcode;
                if (!keysInConfig.add(key)) {
                    // a second well with the same pair could never receive any reads
                    throw new IllegalArgumentException(String.format(
                            "well %s reuses barcode pair %s/%s", wellID, expBarcode, wellBarcode));
                }
                String fastq = well.getString("Fastq");
                File fastqFile = new File(fastq);
                prepareOutputDirectory(fastqFile.getParentFile(), preparedDirs);

                WellData data = new WellData();
                data.id = wellID;
                data.file = fastqFile;
                data.count = 0;
                data.stream = new PrintStream(fastq);
                WellData previous = dataMap.put(key, data);
                if (previous != null && previous.stream != null) {
                    // entry left over from an earlier init() call; don't leak its stream
                    previous.stream.close();
                }

                expBarcodes.add(expBarcode);
                wellBarcodes.add(wellBarcode);
            }

            // open an output file for failures
            failures = new PrintStream(jsonObj.getString("FailuresFile"));
        } catch (Exception exc) {
            closeAll();
            throw exc;
        }
    }

    /**
     * Creates {@code dir} if it does not exist, otherwise deletes the entries directly inside
     * it (non-empty subdirectories survive because {@link File#delete()} cannot remove them).
     * A directory already present in {@code prepared} is skipped, and a {@code null} dir
     * (an output path with no parent component) is left untouched.
     */
    static private void prepareOutputDirectory(File dir, Set<File> prepared) throws IOException {
        if (dir == null || !prepared.add(dir.getCanonicalFile())) {
            return;
        }
        if (!dir.exists()) {
            dir.mkdirs();
            return;
        }
        File[] files = dir.listFiles();
        if (files == null) {
            return;  // not a directory or unreadable; opening the output file will report it
        }
        for (File f : files) {
            f.delete();
        }
    }

    /**
     * Routes every read to its well's output file or to the failures file, closes all files,
     * and records per-well and per-pool read counts in the database.
     * <p>
     * Errors are printed to stderr rather than thrown (Runnable cannot throw checked
     * exceptions). All files are closed whether or not an error occurs, and the database is
     * only updated if every output file was written without error.
     */
    @Override
    public void run() {
        int totalReads = 0;
        int failedReads = 0;
        try {
            try {
                FastqRead expBarcodeRead = FastqRead.read(expFile);
                while (expBarcodeRead != null) {
                    ++totalReads;
                    FastqRead wellBarcodeRead = FastqRead.read(wellFile);
                    FastqRead seqRead = FastqRead.read(seqFile);
                    if (wellBarcodeRead == null || seqRead == null) {
                        throw new IllegalStateException(String.format(
                                "WellBarcodeFile or SequenceFile ended before BarcodeFile (at read %d)",
                                totalReads));
                    }

                    WellData data = findWell(expBarcodeRead, wellBarcodeRead);
                    if (data != null) {
                        ++data.count;
                        String molecularBarcode =
                                wellBarcodeRead.sequence.substring(0, MOLECULAR_BARCODE_LENGTH);
                        success(seqRead, molecularBarcode, data.stream);
                    } else {
                        failure(expBarcodeRead, wellBarcodeRead, seqRead);
                        ++failedReads;
                    }
                    expBarcodeRead = FastqRead.read(expFile);
                }
            } finally {
                closeAll();
            }
            checkOutputErrors();

            // save the read counts in the database
            PreparedStatement state = MySql.getMySql().getStatement("update SingleCellWell set ReadCount = ? where WellID = ?");
            for (WellData data : dataMap.values()) {
                state.setInt(1, data.count);
                state.setString(2, data.id);
                state.execute();
            }
            state = MySql.getMySql().getStatement("update SingleCellPool set TotalReads = ? , FailedReads = ? where PoolID = ?");
            state.setInt(1, totalReads);
            state.setInt(2, failedReads);
            state.setString(3, poolID);
            state.execute();

            // PCR-duplicate filtering is currently disabled; to enable it, call
            // filterByMolecularBarcode(data.file.getPath(), data.id) for each well in dataMap.
        } catch (Exception exc) {
            exc.printStackTrace();
        }
    }

    /**
     * Returns the configured well for this read's barcodes, or {@code null} when either
     * barcode does not match uniquely, the well read is too short to hold a molecular barcode,
     * or the matched pair does not belong to any configured well.
     */
    private WellData findWell(FastqRead expBarcodeRead, FastqRead wellBarcodeRead) {
        String expBarcode = matches(expBarcodeRead.sequence, this.expBarcodes, missAllowed);
        if (expBarcode == null) {
            return null;
        }
        String wellSequence = wellBarcodeRead.sequence;
        if (wellSequence.length() < MOLECULAR_BARCODE_LENGTH) {
            return null;
        }
        String wellBarcode = matches(wellSequence.substring(MOLECULAR_BARCODE_LENGTH), this.wellBarcodes, missAllowed);
        if (wellBarcode == null) {
            return null;
        }
        // expBarcodes and wellBarcodes are collected independently, so two individually valid
        // barcodes may still form a pair that no configured well uses
        return dataMap.get(expBarcode + wellBarcode);
    }

    /**
     * Closes every input and output file opened so far. Safe to call more than once, and
     * safe when init() failed part-way (unopened files are null and skipped).
     */
    private void closeAll() {
        if (failures != null) {
            failures.close();
        }
        for (WellData data : dataMap.values()) {
            if (data.stream != null) {
                data.stream.close();
            }
        }
        closeReader(expFile);
        closeReader(seqFile);
        closeReader(wellFile);
    }

    /** Closes {@code reader} if non-null, printing (not propagating) any close failure. */
    static private void closeReader(BufferedReader reader) {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        } catch (IOException exc) {
            exc.printStackTrace();
        }
    }

    /**
     * Throws if any output stream hit a write or close error. PrintStream never throws on
     * write; it only records the error for {@link PrintStream#checkError()}.
     */
    private void checkOutputErrors() throws IOException {
        if (failures.checkError()) {
            throw new IOException("error writing the failures file");
        }
        for (WellData data : dataMap.values()) {
            if (data.stream.checkError()) {
                throw new IOException("error writing " + data.file);
            }
        }
    }

    /**
     * Writes a gzipped copy of a demultiplexed FASTQ keeping only the first read for each
     * (sequence, header text after the first space) combination, and stores the number of
     * reads kept as {@code UniqueCount} for the well.
     * <p>
     * The output path is the input path with its first "demux.fastq" replaced by
     * "pair1.fastq.gz".
     *
     * @param inputFastqFile demultiplexed FASTQ (plain, or gzipped if it ends in ".gz")
     * @param wellID well whose {@code UniqueCount} is updated
     * @throws IllegalArgumentException if the input path does not contain "demux.fastq"
     *     (the output would otherwise overwrite the input)
     * @throws Exception on I/O or database errors
     */
    static public void filterByMolecularBarcode(String inputFastqFile, String wellID) throws Exception {
        String outputFastqFile = pairedOutputPath(inputFastqFile);
        int count = 0;
        HashSet<String> seen = new HashSet<>();
        try (BufferedReader reader = fastqReader(inputFastqFile);
             FileOutputStream fileOut = new FileOutputStream(outputFastqFile);
             PrintStream writer = new PrintStream(new GZIPOutputStream(fileOut))) {
            FastqRead read = FastqRead.read(reader);
            while (read != null) {
                // the text after the first space holds the molecular barcode added by success();
                // with no space, indexOf returns -1 and the whole header is used
                String key = read.sequence + read.head.substring(read.head.indexOf(" ") + 1);
                if (seen.add(key)) {
                    FastqRead.write(read, writer);
                    ++count;
                }
                read = FastqRead.read(reader);
            }
            // close explicitly so a failure writing the gzip trailer is visible to checkError()
            writer.close();
            if (writer.checkError()) {
                throw new IOException("error writing " + outputFastqFile);
            }
        }
        PreparedStatement state = MySql.getMySql().getStatement("update SingleCellWell set UniqueCount = ? where WellID = ?");
        state.setInt(1, count);
        state.setString(2, wellID);
        state.execute();
    }

    /**
     * Replaces the first literal "demux.fastq" in {@code inputFastqFile} with "pair1.fastq.gz".
     *
     * @throws IllegalArgumentException if "demux.fastq" does not occur in the path
     */
    static private String pairedOutputPath(String inputFastqFile) {
        final String marker = "demux.fastq";
        int at = inputFastqFile.indexOf(marker);
        if (at < 0) {
            throw new IllegalArgumentException(
                    "expected \"demux.fastq\" in input path, got " + inputFastqFile);
        }
        return inputFastqFile.substring(0, at) + "pair1.fastq.gz"
                + inputFastqFile.substring(at + marker.length());
    }

    /**
     * Opens a FASTQ file for reading, decompressing it when the name ends in ".gz".
     *
     * @param fastq path of the FASTQ file
     * @return a reader positioned at the start of the file
     * @throws Exception if the file cannot be opened or is not valid gzip
     */
    static BufferedReader fastqReader(String fastq) throws Exception {
        if (!fastq.endsWith(".gz")) {
            return new BufferedReader(new FileReader(fastq));
        }
        FileInputStream in = new FileInputStream(fastq);
        try {
            return new BufferedReader(new InputStreamReader(new GZIPInputStream(in)));
        } catch (IOException exc) {
            in.close();  // e.g. bad gzip header: don't leak the underlying file
            throw exc;
        }
    }

    // Counts mismatching positions over the common prefix of s1 and s2, stopping as soon as
    // max mismatches are found. Comparing only the common prefix keeps the previous result
    // when s1 is no longer than s2 and avoids an out-of-bounds error when s1 is longer.
    static private int mismatches(String s1, String s2, int max) {
        int shared = Math.min(s1.length(), s2.length());
        int ret = 0;
        for (int i = 0; i < shared; ++i) {
            if (s1.charAt(i) != s2.charAt(i)) {
                ++ret;
                if (ret == max) return ret;
            }
        }
        return ret;
    }

    /**
     * Returns the barcode in {@code barcodes} that {@code s} matches, or {@code null}.
     * An exact match wins immediately; otherwise the barcode within {@code allowedMisses}
     * mismatches is returned, or {@code null} if none or more than one qualifies.
     */
    static String matches(String s, Set<String> barcodes, int allowedMisses) {
        if (barcodes.contains(s)) return s;
        int maxMisses = allowedMisses + 1;
        String ret = null;
        for (String barcode : barcodes) {
            int miss = mismatches(s, barcode, maxMisses);
            if (miss <= allowedMisses) {
                if (ret != null) {
                    return null;  // input string matches more than one barcode
                }
                ret = barcode;
            }
        }
        return ret;
    }

    // Writes an undemultiplexable read to the failures file: the sequence read's header with
    // the experimental, well and sequence reads' bases (and qualities) concatenated.
    private void failure(FastqRead exp, FastqRead well, FastqRead seq) {
        FastqRead read = new FastqRead();
        read.head = seq.head;
        read.quality = String.format("%s%s%s", exp.quality, well.quality, seq.quality);
        read.sequence = String.format("%s%s%s", exp.sequence, well.sequence, seq.sequence);
        FastqRead.write(read, failures);
    }

    // Writes the sequence read to a well's file with the molecular barcode appended to its header.
    private void success(FastqRead seq, String barcode, PrintStream stream) {
        FastqRead read = new FastqRead();
        read.head = String.format("%s %s", seq.head, barcode);
        read.quality = seq.quality;
        read.sequence = seq.sequence;
        FastqRead.write(read, stream);
    }

    /**
     * Replaces the configuration object. No files are reopened; {@link #init()} is what reads
     * the configuration.
     */
    public void setJsonObj(Object o) {
        jsonObj = (JsonObject) o;
    }

    // args[0] - file that has the json object configuring the demultiplex run
    static public void main(String[] args) throws Exception {
        JsonObject obj;
        try (JsonReader reader = Json.createReader(new FileReader(args[0]))) {
            obj = reader.readObject();
        }
        Demultiplexer dem = new Demultiplexer(obj);
        dem.run();
    }

    int missAllowed;  // the number of mismatches allowed in any barcode
    JsonObject jsonObj;  // has all the info to do the demultiplexing
    String poolID;
    BufferedReader expFile;   // the fastq file with the experimental barcodes
    BufferedReader wellFile;  // the fastq file with the molecular and well barcodes
    BufferedReader seqFile;   // the fastq file with the read sequence

    TreeSet<String> expBarcodes = new TreeSet<>();
    TreeSet<String> wellBarcodes = new TreeSet<>();
    TreeMap<String, WellData> dataMap = new TreeMap<>();  // key: expBarcode + wellBarcode
    PrintStream failures;  // the file created with the reads that fail to demultiplex

    /** Output file, stream and read count for one well. */
    static class WellData {
        String id;
        File file;
        PrintStream stream;
        int count;
    }
}
