/* Dynamic binning of sequences, either by G+C content or by expected k-mer
   frequencies (expected k-mer frequencies are modelled by a first-order Markov model)

   compile with:
   gcc -c ExpBin.c -o ExpBin.o
*/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "ExpBin.h"
#include "Markov.h"


/* Allocation helper for new_Binning(): returns a zero-initialised array of
   'count' elements of 'size' bytes each, or prints an error naming 'what'
   and terminates the program if the memory cannot be obtained.
   At least one element is always requested, so that a NULL result
   unambiguously signals failure even when count is 0. */
static void *binning_calloc_or_die(size_t count, size_t size, const char *what) {
  void *mem = calloc(count > 0 ? count : 1, size);
  if(mem == NULL) {
    fprintf(stderr, "ERROR in new_Binning(): out of memory while allocating %s\n", what);
    exit(EXIT_FAILURE);
  }
  return mem;
}

/* Create a Binning object for an array of sequence strings and assign
   every sequence to its bin(s).

   seqarray: array of n sequence strings; the pointer is stored in the
             object, the strings are not copied
   n:        number of sequences
   k:        k-mer length (b->n_kmers is set to pow4_fcn(k))
   mode:     1 = bin sequences by G+C (fills b->gc and b->bins1),
             2 = bin (k-mer, sequence) pairs by expected k-mer frequency
                 (fills b->counts, b->tot, b->freqs and b->bins2)

   Fields that belong to the other mode are set to NULL.  The temporary
   fields (breaks1 in mode 1; f_exp and breaks2 in mode 2) are released by
   define_bins_byGC() / define_bins_byExpFreq() before this function returns.

   An unknown mode, a negative n or an allocation failure prints an error
   to stderr and terminates the program with EXIT_FAILURE.
   The result must be released with destroy_Binning(). */
Binning *new_Binning(char** seqarray, int n, int k, int mode) {
  Binning *b;
  size_t n_seqs;
  int i, j;

  if(mode != 1 && mode != 2) {
    fprintf(stderr, "ERROR in new_Binning(): unknown mode %d (should be 1:by_G+C or 2:by_exp_freq)\n", mode);
    exit(EXIT_FAILURE);
  }
  if(n < 0) {
    fprintf(stderr, "ERROR in new_Binning(): negative number of sequences (%d)\n", n);
    exit(EXIT_FAILURE);
  }
  n_seqs = (size_t) n;

  b = binning_calloc_or_die(1, sizeof *b, "Binning");
  b->mode    = mode;
  b->seqs    = seqarray;
  b->n_seqs  = n;
  b->k       = k;
  b->n_kmers = pow4_fcn(k);

  if(mode == 1) {
    b->gc      = binning_calloc_or_die(n_seqs, sizeof *b->gc,      "gc");
    b->breaks1 = binning_calloc_or_die(N_BINS, sizeof *b->breaks1, "breaks1");
    b->bins1   = binning_calloc_or_die(n_seqs, sizeof *b->bins1,   "bins1");
    for(j = 0; j < n; j++)
      b->bins1[j] = N_BINS + 1; // sentinel: bin has not yet been calculated

    b->counts  = NULL;
    b->freqs   = NULL;
    b->f_exp   = NULL;
    b->breaks2 = NULL;
    b->bins2   = NULL;

    define_bins_byGC(b);      // fills in b->gc, b->bins1, and temporarily b->breaks1

  } else {
    const size_t n_kmers = (size_t) b->n_kmers;

    b->gc      = NULL;
    b->breaks1 = NULL;
    b->bins1   = NULL;

    // index 0: mononucleotides (4 entries), index 1: dinucleotides (16 entries)
    b->counts    = binning_calloc_or_die(2,  sizeof *b->counts,    "counts");
    b->counts[0] = binning_calloc_or_die(4,  sizeof **b->counts,   "counts[0]");
    b->counts[1] = binning_calloc_or_die(16, sizeof **b->counts,   "counts[1]");
    b->tot[0]    = 0;
    b->tot[1]    = 0;
    b->freqs     = binning_calloc_or_die(2,  sizeof *b->freqs,     "freqs");
    b->freqs[0]  = binning_calloc_or_die(4,  sizeof **b->freqs,    "freqs[0]");
    b->freqs[1]  = binning_calloc_or_die(16, sizeof **b->freqs,    "freqs[1]");

    // one row per k-mer code in each of the following tables
    b->f_exp   = binning_calloc_or_die(n_kmers, sizeof *b->f_exp,   "f_exp");
    b->breaks2 = binning_calloc_or_die(n_kmers, sizeof *b->breaks2, "breaks2");
    b->bins2   = binning_calloc_or_die(n_kmers, sizeof *b->bins2,   "bins2");
    for(i = 0; i < b->n_kmers; i++) {
      b->f_exp[i]   = binning_calloc_or_die(n_seqs, sizeof **b->f_exp,   "f_exp row");
      b->breaks2[i] = binning_calloc_or_die(N_BINS, sizeof **b->breaks2, "breaks2 row");
      b->bins2[i]   = binning_calloc_or_die(n_seqs, sizeof **b->bins2,   "bins2 row");
      for(j = 0; j < n; j++)
        b->bins2[i][j] = N_BINS + 1; // sentinel: bin has not yet been calculated
    }

    define_bins_byExpFreq(b); // fills in b->counts, b->tot, b->freqs, b->bins2, and temporarily b->f_exp and b->breaks2
  }

  return b;
}

/* Mode 1: count G+C in every sequence (upper-case 'G'/'C' only, no
   pseudocounts), derive the breaks of N_BINS equally sized bins from the
   sorted counts and store the bin of every sequence in b->bins1.

   Expects b->gc, b->breaks1 and b->bins1 to be allocated.  b->breaks1 is
   only needed while the bins are assigned; it is freed and set to NULL
   before returning.  With no sequences there is nothing to bin and only
   the clean-up is performed.

   NOTE(review): the counters in b->gc are sorted as unsigned short int,
   so a sequence with more G+C than that type can hold would wrap around. */
void define_bins_byGC(Binning* b) {
  int i, j;

  // calculate G+C (for each sequence)
  for(i = 0; i < b->n_seqs; i++) {
    const char *p;
    b->gc[i] = 0;
    for(p = b->seqs[i]; *p != '\0'; p++)
      if(*p == 'G' || *p == 'C')
        b->gc[i]++;
  }

  if(b->n_seqs > 0) {
    // sort a heap copy of the counts (a stack array could overflow the
    // stack for large inputs) and fill in b->breaks1
    const double bin_size = (double) b->n_seqs / N_BINS;
    unsigned short int *sorted_gc = malloc((size_t) b->n_seqs * sizeof *sorted_gc);
    if(sorted_gc == NULL) {
      fprintf(stderr, "ERROR in define_bins_byGC(): out of memory\n");
      exit(EXIT_FAILURE);
    }
    for(i = 0; i < b->n_seqs; i++)
      sorted_gc[i] = b->gc[i];
    qsort(sorted_gc, b->n_seqs, sizeof(unsigned short int), compare_unsigned_short_ints);

    // break j is the smallest value of bin j+1; the last break is the maximum
    for(j = 0; j < N_BINS - 1; j++)
      b->breaks1[j] = sorted_gc[(int)((j+1) * bin_size)];
    b->breaks1[N_BINS - 1] = sorted_gc[b->n_seqs - 1];
    free(sorted_gc);

    // for each sequence, calculate and store bin
    for(i = 0; i < b->n_seqs; i++)
      b->bins1[i] = get_bin_for_seq(b, 0, i);
  }

  // clean up
  free(b->breaks1);
  b->breaks1 = NULL;
}

/* Mode 2: for each k-mer, sort sequences by expected k-mer frequency and
   calculate binning breaks for N_BINS equally sized bins.

   Steps:
     1. count mono- and dinucleotides per sequence with count_kmers() and
        accumulate them into b->counts / b->tot,
     2. turn the accumulated counts into overall frequencies (b->freqs),
     3. per sequence, smooth its own frequencies with
        PRIORWEIGHT * (number of k-mers in the sequence) pseudocounts
        distributed according to b->freqs, and evaluate expected_freq()
        for every k-mer code into b->f_exp,
     4. per k-mer, sort the expected frequencies, derive b->breaks2 and
        store the bin of every sequence in b->bins2.

   Expects b->counts, b->freqs, b->f_exp, b->breaks2 and b->bins2 to be
   allocated.  b->f_exp and b->breaks2 are temporary: they are freed and
   set to NULL before returning.

   A total of zero never divides: overall frequencies are then 0, and a
   sequence without any counted k-mer uses the overall frequencies as is.
   Allocation failure prints an error to stderr and terminates the program. */
void define_bins_byExpFreq(Binning* b) {
  // never request zero bytes, so that NULL always means failure
  const size_t n_alloc = b->n_seqs > 0 ? (size_t) b->n_seqs : 1;
  unsigned short int *n_mono, *n_di; // per-sequence counts: 4 resp. 16 entries per sequence
  int *tot_mono, *tot_di;            // per-sequence totals
  float *sorted;                     // copy of one b->f_exp row for sorting
  double f_mono[4];
  double f_di[16];
  double tot_pc;
  char* kmer;
  int code, i, j;

  n_mono   = calloc(n_alloc, 4  * sizeof *n_mono);
  n_di     = calloc(n_alloc, 16 * sizeof *n_di);
  tot_mono = calloc(n_alloc, sizeof *tot_mono);
  tot_di   = calloc(n_alloc, sizeof *tot_di);
  sorted   = calloc(n_alloc, sizeof *sorted);
  if(n_mono == NULL || n_di == NULL || tot_mono == NULL || tot_di == NULL || sorted == NULL) {
    fprintf(stderr, "ERROR in define_bins_byExpFreq(): out of memory\n");
    exit(EXIT_FAILURE);
  }

  // calculate mono-/dinucleotide counts (per sequence and over all sequences)
  for(i = 0; i < b->n_seqs; i++) {
    unsigned short int *mono = n_mono + 4  * (size_t) i;
    unsigned short int *di   = n_di   + 16 * (size_t) i;

    tot_mono[i] = count_kmers(1, b->seqs[i], mono);
    b->tot[0]  += tot_mono[i];
    for(j = 0; j < 4; j++)
      b->counts[0][j] += (int) mono[j];

    tot_di[i]   = count_kmers(2, b->seqs[i], di);
    b->tot[1]  += tot_di[i];
    for(j = 0; j < 16; j++)
      b->counts[1][j] += (int) di[j];
  }

  // calculate mono-/dinucleotide frequencies (over all sequences)
  for(j = 0; j < 4; j++)
    b->freqs[0][j] = b->tot[0] > 0 ? (double) b->counts[0][j] / b->tot[0] : 0.0;
  for(j = 0; j < 16; j++)
    b->freqs[1][j] = b->tot[1] > 0 ? (double) b->counts[1][j] / b->tot[1] : 0.0;

  // calculate expected k-mer frequencies (for each sequence)
  for(i = 0; i < b->n_seqs; i++) {
    const unsigned short int *mono = n_mono + 4  * (size_t) i;
    const unsigned short int *di   = n_di   + 16 * (size_t) i;

    tot_pc = PRIORWEIGHT * tot_mono[i];
    for(j = 0; j < 4; j++)
      f_mono[j] = tot_mono[i] > 0
        ? (mono[j] + b->freqs[0][j] * tot_pc) / (tot_mono[i] + tot_pc)
        : b->freqs[0][j]; // no observations: fall back to the overall frequency

    tot_pc = PRIORWEIGHT * tot_di[i];
    for(j = 0; j < 16; j++)
      f_di[j]   = tot_di[i] > 0
        ? (di[j] + b->freqs[1][j] * tot_pc) / (tot_di[i] + tot_pc)
        : b->freqs[1][j];

    for(code = 0; code < b->n_kmers; code++) {
      // NOTE(review): ownership of the string returned by code_to_kmer()
      // is not visible here; if it is heap-allocated it leaks -- confirm
      kmer  = code_to_kmer(code, b->k);
      b->f_exp[code][i] = expected_freq(b, kmer, f_mono, f_di);
    }
  }

  if(b->n_seqs > 0) {
    // now sort f_exp and fill in b->breaks2
    const double bin_size = (double) b->n_seqs / N_BINS;
    for(code = 0; code < b->n_kmers; code++) {
      for(i = 0; i < b->n_seqs; i++)
        sorted[i] = b->f_exp[code][i];
      qsort(sorted, b->n_seqs, sizeof(float), compare_floats);

      // break j is the smallest value of bin j+1; the last break is the maximum
      for(j = 0; j < N_BINS - 1; j++)
        b->breaks2[code][j] = sorted[(int)((j+1) * bin_size)];
      b->breaks2[code][N_BINS - 1] = sorted[b->n_seqs - 1];
    }

    // for each (kmer,sequence)-pair, calculate and store bin
    for(code = 0; code < b->n_kmers; code++)
      for(i = 0; i < b->n_seqs; i++)
        b->bins2[code][i] = get_bin_for_seq(b, code, i);
  }

  // clean up
  for(i = 0; i < b->n_kmers; i++) {
    free(b->f_exp[i]);
    free(b->breaks2[i]);
  }
  free(b->f_exp);
  free(b->breaks2);
  b->f_exp = NULL;
  b->breaks2 = NULL;
  free(sorted);
  free(n_mono);
  free(n_di);
  free(tot_mono);
  free(tot_di);
}

/* Get the bin for a k-mer (code) with a given expected frequency (f_exp).

   Returns the index of the first break in b->breaks2[code] that is not
   smaller than f_exp; values above every break end up in the last bin
   (N_BINS - 1).

   Requires b->breaks2, which only exists while the bins are being
   defined.  If it has already been freed, an error is printed to stderr
   and the program terminates with EXIT_FAILURE. */
int get_bin_for_kmer_expfreq(Binning* b, int code, float f_exp) {
  int bin;

  if( b->breaks2 == NULL ) {
    fprintf(stderr, "ERROR in get_bin_for_kmer_expfreq(): cannot calculate bin (float** breaks2 has been freed)\n");
    exit(EXIT_FAILURE);
  }

  for(bin = 0; bin < N_BINS - 1 && b->breaks2[code][bin] < f_exp; bin++)
    ; // advance to the first break >= f_exp

  return bin;
}

/* Get the bin for a sequence with a given G+C count (gc).

   Returns the index of the first break in b->breaks1 that is not smaller
   than gc; values above every break end up in the last bin (N_BINS - 1).

   Requires b->breaks1, which only exists while the bins are being
   defined.  If it has already been freed, an error is printed to stderr
   and the program terminates with EXIT_FAILURE. */
int get_bin_for_gc(Binning* b, int gc) {
  int bin;

  if( b->breaks1 == NULL ) {
    fprintf(stderr, "ERROR in get_bin_for_gc(): cannot calculate bin (breaks1 has been freed)\n");
    exit(EXIT_FAILURE);
  }

  for(bin = 0; bin < N_BINS - 1 && b->breaks1[bin] < gc; bin++)
    ; // advance to the first break >= gc

  return bin;
}

/* Get the bin of sequence s (mode 1) or of the pair (k-mer code,
   sequence s) (mode 2).  'code' is not used in mode 1.

   A bin already stored in b->bins1 / b->bins2 is returned directly (the
   value N_BINS+1 marks an entry that has not been calculated yet).
   Otherwise the bin is derived from the temporary data: b->gc and
   b->breaks1 in mode 1, b->f_exp (and b->breaks2) in mode 2.

   If neither source is available, or b->mode is neither 1 nor 2, an error
   is printed to stderr and the program terminates with EXIT_FAILURE. */
int get_bin_for_seq(Binning* b, int code, int s) {
  if( b->mode == 1 ) {
    if( b->bins1 != NULL && b->bins1[s] != N_BINS+1 )
      return b->bins1[s];
    if( b->gc != NULL && b->breaks1 != NULL )
      return get_bin_for_gc(b, b->gc[s]);

    fprintf(stderr, "ERROR in get_bin_for_seq(): don't know how to calculate bin without 'bins1','gc','breaks1'\n");
    exit(EXIT_FAILURE);
  }

  if( b->mode == 2 ) {
    if( b->bins2 != NULL && b->bins2[code][s] != N_BINS+1 )
      return b->bins2[code][s];
    if( b->f_exp != NULL )
      return get_bin_for_kmer_expfreq(b, code, b->f_exp[code][s]);

    fprintf(stderr, "ERROR in get_bin_for_seq(): don't know how to calculate bin without 'bins2','f_exp'\n");
    exit(EXIT_FAILURE);
  }

  // previously control fell off the end of this non-void function here
  fprintf(stderr, "ERROR in get_bin_for_seq(): unknown mode %d (should be 1:by_G+C or 2:by_exp_freq)\n", b->mode);
  exit(EXIT_FAILURE);
}

/* qsort() comparator for arrays of double (ascending order):
   returns -1 if *a < *b, 0 if they are equal, and 1 otherwise */
int compare_doubles( const void* a, const void* b ) {
  const double lhs = *(const double*) a;
  const double rhs = *(const double*) b;

  return (lhs < rhs) ? -1 : ((lhs == rhs) ? 0 : 1);
}

/* qsort() comparator for arrays of float (ascending order):
   returns -1 if *a < *b, 0 if they are equal, and 1 otherwise */
int compare_floats( const void* a, const void* b ) {
  const float lhs = *(const float*) a;
  const float rhs = *(const float*) b;

  return (lhs < rhs) ? -1 : ((lhs == rhs) ? 0 : 1);
}

/* qsort() comparator for arrays of unsigned short int (ascending order):
   returns -1, 0 or 1 for *a smaller than, equal to or greater than *b */
int compare_unsigned_short_ints( const void* a, const void* b ) {
  const unsigned short int lhs = *(const unsigned short int*) a;
  const unsigned short int rhs = *(const unsigned short int*) b;

  // each comparison yields 0 or 1, so the difference is -1, 0 or 1
  return (lhs > rhs) - (lhs < rhs);
}

/* Count the k-mers of seq.

   k:      k-mer length
   seq:    NUL-terminated sequence
   counts: output array with (at least) pow4_fcn(k) entries; it is reset
           to zero and then incremented at index kmer_to_code() for every
           position of seq at which a full k-mer starts

   Returns the total number of k-mers counted.  A sequence shorter than k
   contains no k-mer: counts is all zero and 0 is returned.

   NOTE(review): what kmer_to_code() returns for characters other than
   A/C/G/T is not visible here; the code is used as an index unchecked. */
int count_kmers(int k, char* seq, unsigned short int* counts) {
  const int    n_kmers = pow4_fcn(k);
  const size_t len     = strlen(seq);
  size_t pos, last;
  int i, code;
  int tot = 0;

  for(i = 0; i < n_kmers; i++)
    counts[i] = 0;

  // strlen() is unsigned: "len - k" would wrap around to a huge value for
  // sequences shorter than k, so test before subtracting
  if(k < 0 || len < (size_t) k)
    return 0;

  last = len - (size_t) k; // last position at which a full k-mer starts
  for(pos = 0; pos <= last; pos++) {
    code = kmer_to_code(seq + pos, k);
    counts[code]++;
    tot++;
  }

  return tot;
}

/* Calculate the expected frequency of a k-mer in a sequence based on a
   1st-order Markov model:

     f(x1..xk) = f_di(x1 x2) * prod_{i=2..k-1} f_di(xi xi+1) / f_mono(xi)

   b:      not used; kept so that existing callers keep working
   kmer:   NUL-terminated k-mer string
   f_mono: mononucleotide frequencies, indexed by kmer_to_code(.., 1)
   f_di:   dinucleotide frequencies, indexed by kmer_to_code(.., 2)

   Short k-mers have no dinucleotide to start from: a 1-mer returns its
   mononucleotide frequency and an empty string returns 0.  An inner
   nucleotide with a frequency of zero makes the whole k-mer impossible,
   so 0 is returned instead of dividing by zero. */
float expected_freq(Binning* b, char* kmer, double* f_mono, double* f_di) {
  const size_t len = strlen(kmer);
  size_t pos;
  int code_mono, code_di;
  double freq;

  (void) b;

  if(len == 0)
    return 0.0f;
  if(len == 1)
    return (float) f_mono[kmer_to_code(kmer, 1)];

  freq = f_di[kmer_to_code(kmer, 2)];
  // "pos + 1 < len" instead of "pos < len - 1": no unsigned subtraction
  for(pos = 1; pos + 1 < len; pos++) {
    code_mono = kmer_to_code(kmer + pos, 1);
    code_di   = kmer_to_code(kmer + pos, 2);
    if(f_mono[code_mono] <= 0.0)
      return 0.0f;
    freq *= (f_di[code_di] / f_mono[code_mono]);
  }

  return (float) freq;
}

/* Free a Binning object and every array it owns.

   Passing NULL is allowed and does nothing.  Arrays that were never
   allocated or have already been released are expected to be NULL.
   The sequence strings (b->seqs) are not freed here. */
void destroy_Binning(Binning* b) {
  int i;

  if( b == NULL )
    return;

  // mode 1 data (free(NULL) is a no-op, so no guards are needed)
  free(b->gc);
  free(b->breaks1);
  free(b->bins1);

  // mode 2 data: mono-/dinucleotide tables have two rows each
  if( b->counts != NULL ) {
    for(i = 0; i < 2; i++)
      free(b->counts[i]);
    free(b->counts);
  }
  if( b->freqs != NULL ) {
    for(i = 0; i < 2; i++)
      free(b->freqs[i]);
    free(b->freqs);
  }

  // mode 2 data: per-k-mer tables have b->n_kmers rows each
  if( b->f_exp != NULL ) {
    for(i = 0; i < b->n_kmers; i++)
      free(b->f_exp[i]);
    free(b->f_exp);
  }
  if( b->breaks2 != NULL ) {
    for(i = 0; i < b->n_kmers; i++)
      free(b->breaks2[i]);
    free(b->breaks2);
  }
  if( b->bins2 != NULL ) {
    for(i = 0; i < b->n_kmers; i++)
      free(b->bins2[i]);
    free(b->bins2);
  }

  free(b);
}
