#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "LSM.h"
#include "Markov.h"
#include "SeqIO.h"


// Nucleotide symbols indexed by numeric code: 0='A', 1='C', 2='G', 3='T'.
// Inverse of alpha_to_code(). Four chars, not NUL-terminated.
char alphabet[] = {[0] = 'A', [1] = 'C', [2] = 'G', [3] = 'T'};
// Map a nucleotide character to its numeric code, i.e. its index in
// alphabet[]: 'A'->0, 'C'->1, 'G'->2, 'T'->3.
// Only upper-case A/C/G/T are accepted; any other character prints an
// error to stdout (the character and its octal value) and returns -1.
//
// Deliberately not `inline`: under C99/C11 a plain `inline` definition is
// not an external definition, so any call the compiler does not inline
// (e.g. at -O0) can fail to link unless some other declaration provides
// one. An ordinary definition links under every language mode.
int alpha_to_code(char n) {
  switch(n) {
  case 'A': return 0;
  case 'C': return 1;
  case 'G': return 2;
  case 'T': return 3;
  default:  break;
  }
  // %o expects an unsigned int; go through unsigned char so a negative
  // (high-bit) char prints as its byte value.
  printf("Error: couldn't look up alphabet '%c' (\\%3.3o)\n",
         n, (unsigned)(unsigned char)n);
  return -1;
}


// Convert the inclusive range of nucleotide codes [start, end] into a newly
// malloc'd NUL-terminated string, mapping each code c to alphabet[c].
// The caller owns the result and must free() it.
// Codes are assumed to lie in 0..3 (not checked).
// end == start - 1 yields "". A range with end further before start is
// treated as empty instead of writing before the allocated buffer.
// Returns NULL if the allocation fails.
char *IA_to_str(int *start, int *end) {
  int n = (int)(end - start) + 1;
  if (n < 0) n = 0;

  char *s = (char*)malloc(((size_t)n + 1) * sizeof(char));
  if (s == NULL) return NULL;

  int i;
  for (i = 0; i < n; i++) s[i] = alphabet[start[i]];
  s[n] = '\0';
  return s;
}

// all_kmers[k] is the list of all kmers of length k, in alphabetical
// order (all_kmers[k][code] is the kmer with that code).
// pow4[k] = 4^k.
// Both are valid for k = 0 .. max_k and are maintained by
// make_all_kmers(); max_k stays -1 until the tables are first built.

// Shared kmer lookup tables, built or extended by make_all_kmers() and
// released by free_all_kmers().
char ***all_kmers;  // all_kmers[k][code]: the length-k kmer string with that code
int *pow4;  // pow4[k] = 4^k for k = 0 .. max_k; pow4[max_k+1] = -1 marks the end
int max_k = -1;  // largest k the tables cover; -1 means not built yet


// Release the shared kmer tables (every string in all_kmers, the per-k
// arrays, all_kmers itself and pow4). Afterwards the globals are reset to
// the "not built" state (all_kmers = pow4 = NULL, max_k = -1).
// Any kmer string previously obtained from code_to_kmer() is invalid
// afterwards. Safe to call when nothing has been built.
void free_all_kmers() {
  int k;
  for(k = 0; k <= max_k; k++) {
    int code;
    for(code = 0; code < pow4[k]; code++) {
      free(all_kmers[k][code]);
    }
    free(all_kmers[k]);
  }
  free(pow4);
  free(all_kmers);

  // Without this reset max_k would still claim the tables exist, and
  // make_all_kmers() would hand back the freed all_kmers pointer.
  all_kmers = NULL;
  pow4 = NULL;
  max_k = -1;
}


// returns an array of length K+1 of
// char**'s corresponding to all kmers of 
// given length
// Save result in all_kmers and set max_k
// If all_kmers already exists make sure to
// free the memory first
// If max_k >= K then don't do anything.
char*** make_all_kmers(int K) {
  //  printf("make_all_kmers(%d)\n", K);
  if (max_k >= K) return all_kmers;

  // ak - new all_kmers
  char ***ak = (char***) malloc((K+1) * sizeof(char**));  
  // p4 - new pow4
  int *p4 = (int*)malloc((K+2)*sizeof(int));
  p4[K+1] = -1;
  
  ak[0] = (char**)malloc(sizeof(char*));
  ak[0][0] = strdup("");
  p4[0] = 1;

  int i,j,k;
  for(k = 1; k <= K; k++) {
    p4[k] = 4 * p4[k-1];
    ak[k] = (char**)malloc(p4[k] * sizeof(char*));
    for(i=0; i < p4[k-1]; i++) 
      for(j=0; j < 4; j++) {
	int code = 4*i + j;
	ak[k][code] = (char*) malloc ((k+1)*sizeof(char));
	strcpy(ak[k][code], ak[k-1][i]);
	ak[k][code][k-1] = alphabet[j];
	ak[k][code][k] = '\0';
      }
  }
  
  if(max_k > -1) free_all_kmers();

  max_k = K;
  pow4 = p4;
    
  return all_kmers = ak;
}


// Return the length-k kmer whose code is `code`. Codes are in alphabetical
// order, first character most significant.
// The string lives in the shared all_kmers table. Do not free or modify it,
// and do not keep it across a make_all_kmers() call that rebuilds the
// tables, because the old strings are freed then.
// Exits with status 4 if k is outside 0 .. max_k, and with status 5 if
// code is outside 0 .. 4^k - 1.
char* code_to_kmer(int code, int k) {
  if (k < 0) {
    printf("Error: code_to_kmer(code = %d, k = %d) but k < 0\n", code, k);
    exit(4);
  }
  if (max_k < k) {
    printf("Error: code_to_kmer(code = %d, k = %d) but max_k = %d < k\n",
           code, k, max_k);
    exit(4);
  }
  if (code < 0) {
    printf("Error: code_to_kmer(code = %d, k = %d) but code < 0\n", code, k);
    exit(5);
  }
  if (code >= pow4[k]) {
    printf("Error: code_to_kmer(code = %d, k = %d) but n_kmers = %d <= code\n",
           code, k, pow4[k]);
    exit(5);
  }
  return all_kmers[k][code];
}

// Return the code of the first k characters of kmer, read as base-4
// digits (A=0, C=1, G=2, T=3) with the first character most significant.
// This matches the alphabetical order of all_kmers[k].
// k must be supplied explicitly; only the initial k characters are used,
// so kmer may point into a longer sequence. k <= 0 gives 0.
// Returns -1 after printing an error if any of those characters is not
// A/C/G/T, including when kmer is shorter than k.
int kmer_to_code(char *kmer, int k) {
  int code = 0;
  int i;
  // Left-to-right (Horner) evaluation gives the same value as summing
  // digit * 4^(k-1-i). It also stops at the string's NUL terminator, which
  // alpha_to_code() rejects, instead of first reading kmer[k-1], which lies
  // past the end when strlen(kmer) < k.
  for(i = 0; i < k; i++)  {
    int alpha_code = alpha_to_code(kmer[i]);
    if(alpha_code == -1) {
      printf("Error: Couldn't convert kmer [%s] to a code\n", kmer);
      return -1;
    }
    code = 4 * code + alpha_code;
  }
  return code;
}
  


// Total probability mass of a position: p[0] + p[1] + p[2] + p[3]
// (the A, C, G, T entries; the p[4] sentinel is excluded).
double Pos_p_sum(Pos* pos) {
  double total = 0;
  int nt = 0;
  while (nt < 4) total += pos->p[nt++];
  return total;
}

// Nonzero when pos holds a usable distribution: none of p[0..3] is
// negative, and their sum equals 1 according to fpe().
// NOTE(review): fpe() is defined elsewhere; presumably a tolerant
// floating-point equality test. Confirm.
int Pos_OK(Pos* pos) {
  int nt = 0;
  while (nt < 4 && !(pos->p[nt] < 0)) nt++;
  if (nt < 4) return 0;  // stopped early: found a negative entry
  return fpe(Pos_p_sum(pos), 1.0);
}


  
// Fill *pos with the distribution (fA, fC, fG, fT) over A, C, G, T and
// return pos.
//  - p[0..3] receive the values and logp[0..3] their log2_rounded()
//    values. p[4] and logp[4] are set to -1 as sentinels.
//  - A negative entry is fatal: an error is printed and the program exits
//    with status 6.
//  - If the entries are not a distribution per Pos_OK(), a warning goes to
//    stderr, they are divided by their sum, and the position is set again.
//  - If that sum is zero, NaN or infinite, the program exits with status 6.
//    Dividing by such a sum would yield NaNs that never pass Pos_OK(), so
//    the rescaling would recurse forever.
Pos* set_Pos(Pos *pos, double fA, double fC, double fG, double fT)  {
  pos->p[0] = fA;
  pos->p[1] = fC;
  pos->p[2] = fG;
  pos->p[3] = fT;
  pos->p[4] = -1;

  int i;
  for(i = 0; i < 4; i++) pos->logp[i] = log2_rounded(pos->p[i]);
  pos->logp[4] = -1;

  if (Pos_OK(pos)) return pos;

  for(i = 0; i < 4; i++) {
    if (pos->p[i] < 0) {
      printf("Error: new_Pos(%g, %g, %g, %g)--negative probability\n",
             fA, fC, fG, fT);
      exit(6);
    }
  }

  double sum = Pos_p_sum(pos);
  if (!isfinite(sum) || sum <= 0) {
    printf("Error: new_Pos(%g, %g, %g, %g)--can't rescale (sum = %g)\n",
           fA, fC, fG, fT, sum);
    exit(6);
  }

  fprintf(stderr, "Warning: new_Pos(%g, %g, %g, %g) not a prob (sum = %g)\n",
          fA, fC, fG, fT, sum);
  for(i = 0; i < 4; i++) pos->p[i] /= sum;
  fprintf(stderr, "Rescaling to (%g, %g, %g, %g)\n",
          pos->p[0], pos->p[1], pos->p[2], pos->p[3]);
  return set_Pos(pos, pos->p[0], pos->p[1], pos->p[2], pos->p[3]);
}

Pos* new_Pos(double fA, double fC, double fG, double fT) {
  Pos *pos = (Pos*)malloc(sizeof(Pos));
  return set_Pos(pos, fA, fC, fG, fT);
}

// Return a newly allocated copy of *old. It is built through new_Pos(), so
// logp is recomputed from p.
// A NULL old yields NULL. Markov models start with all positions NULL
// (see new_Markov), so partially filled models can still be copied.
Pos* copy_Pos(Pos *old) {
  if (old == NULL) return NULL;
  return new_Pos(old->p[0], old->p[1], old->p[2], old->p[3]);
}

// Read one position from fp. The next line returned by next_line() must
// start with four numbers, fA fC fG fT (whitespace separated).
// Returns a newly allocated Pos (see new_Pos()).
// Exits with status 5 if no line can be read, or if the line does not
// start with four numbers. Without this check the values would be used
// uninitialized.
Pos* next_Pos(FILE *fp) {
  double fA, fC, fG, fT;
  char line[MAX_LINE_LEN+1];

  if(!next_line(fp, line))  {
    printf("Error (next_Pos()): couldn't get a line from your file\n");
    exit(5);
  }
  if(sscanf(line, "%lg%lg%lg%lg", &fA, &fC, &fG, &fT) != 4)  {
    printf("Error (next_Pos()): couldn't read four numbers from line:\n%s\n",
           line);
    exit(5);
  }

  return new_Pos(fA, fC, fG, fT);
}
  
// Write one position to fp as a single tab-separated line.
// save_what == PROBS writes the four probabilities; save_what == LOGP
// writes the four rounded log values. Any other value writes nothing.
void save_Pos(Pos *pos, FILE* fp, int save_what) {
  if (save_what == PROBS) {
    fprintf(fp, "%g\t%g\t%g\t%g\n",
            pos->p[0], pos->p[1], pos->p[2], pos->p[3]);
  } else if (save_what == LOGP) {
    fprintf(fp, "%d\t%d\t%d\t%d\n",
            pos->logp[0], pos->logp[1], pos->logp[2], pos->logp[3]);
  }
}

// Draw one random nucleotide character from pos's distribution.
// rand_from_p() picks the index into alphabet[].
char Pos_generate(Pos *pos) {
  int drawn = rand_from_p(pos->p);
  return alphabet[drawn];
}



// Allocate a counter for all kmers of length k, with every count zeroed.
// Also makes sure the shared kmer tables cover k.
// count has pow4[k] + 1 entries; reset_KmerCounter() stores -1 in the last
// one as an end marker.
// Exits with status 11 if k < 0.
KmerCounter *new_KmerCounter(int k)  {
  if (k < 0) {
    printf("new_KmerCounter() error: k (%d) < 0\n", k);
    exit(11);
  }

  KmerCounter *counter = (KmerCounter*) malloc(sizeof *counter);
  make_all_kmers(k);
  counter->k = k;

  int slots = pow4[k] + 1;
  counter->count = (double*) malloc(slots * sizeof(double));
  reset_KmerCounter(counter);
  return counter;
}

// Print every kmer of the counter's length with its count to stdout.
// Output is one "kmer<TAB>count" line per kmer, in code (alphabetical) order.
void dump_KmerCounter(KmerCounter *kc) {
  int k = kc->k;
  int total = pow4[k];
  int c;
  for (c = 0; c < total; c++) {
    printf("%s\t%g\n", code_to_kmer(c, k), kc->count[c]);
  }
}


// Free a counter made by new_KmerCounter(): its count array and the struct.
// The shared kmer tables are left alone.
void destroy_KmerCounter(KmerCounter *kc)  {
  double *counts = kc->count;
  free(kc);
  free(counts);
}


// Zero counts 0 .. 4^k - 1 and store -1 in the extra slot at index 4^k as
// an end marker. Returns kc.
KmerCounter *reset_KmerCounter(KmerCounter *kc)  {
  int n = pow4[kc->k];
  int i = n;
  while (i-- > 0) kc->count[i] = 0;
  kc->count[n] = -1.0;
  return kc;
}

// Add `weight` to the count of every length-k window of s (k = kc->k),
// overlapping windows included. Returns kc.
// A window with a character other than A/C/G/T (e.g. 'N') has no code:
// kmer_to_code() reports it and returns -1. Such windows are skipped
// instead of being written to count[-1].
KmerCounter *add_kmers(KmerCounter *kc, char *s, double weight)  {
  int k = kc->k;
  int n = (int) strlen(s);
  int i;
  for(i = 0; i + k <= n; i++)  {
    int code = kmer_to_code(s + i, k);
    if (code < 0) continue;  // invalid window: nothing to count
    kc->count[code] += weight;
  }
  return kc;
}



// Allocate an order-k Markov model whose 4^k positions are all unset
// (NULL). Also makes sure the shared kmer tables cover k.
// Exits with status 11 if k < 0.
Markov *new_Markov(int k)  {
  if (k < 0) {
    printf("new_Markov() error: k (%d) < 0\n", k);
    exit(11);
  }

  Markov *model = (Markov*) malloc(sizeof *model);
  make_all_kmers(k);
  model->k = k;

  int n_ctx = pow4[k];
  model->p = (Pos**) malloc(n_ctx * sizeof(Pos*));
  int c;
  for (c = 0; c < n_ctx; c++) model->p[c] = NULL;
  return model;
}

// Return a deep copy of old: a new model of the same order whose positions
// are fresh copies of old's.
// Positions that are still NULL in old (as left by new_Markov) stay NULL
// in the copy instead of being dereferenced.
Markov *copy_Markov(Markov *old) {
  Markov *new = new_Markov(old->k);
  int n = pow4[old->k];
  int i;
  for(i = 0; i < n; i++)
    new->p[i] = old->p[i] ? copy_Pos(old->p[i]) : NULL;
  return new;
}

// Free a model: each of its 4^k positions (NULL entries are fine), the
// position array and the struct. The shared kmer tables are untouched.
void destroy_Markov(Markov *m)  {
  int remaining = pow4[m->k];
  Pos **slots = m->p;
  while (remaining-- > 0) free(slots[remaining]);
  free(slots);
  free(m);
}

// Read a Markov model from fp in the layout save_Markov(..., PROBS) writes:
//     k = <order>
//     <kmer> <pA> <pC> <pG> <pT>   (one line per kmer; for k = 0 the
//                                   kmer column is absent)
// Lines are obtained with next_line(). Each of the 4^k kmers must appear
// exactly once. Problems print an error and exit with:
//    7  missing header, or header not of the form "k = %d"
//    8  unparsable line, kmer of the wrong length, or kmer containing a
//       character other than A/C/G/T
//    9  repeated kmer
//   10  wrong number of kmers
Markov *load_Markov(FILE *fp)  {
  char line[MAX_LINE_LEN+1];
  // No header line at all: treat it like an unparsable header.
  if(!next_line(fp, line)) line[0] = '\0';

  char k_str[101];
  int k;
  // The field width keeps the scan within k_str.
  int n_scanned = sscanf(line, "%100s = %d", k_str, &k);
  if((n_scanned < 2) || strcmp(k_str, "k"))  {
      printf("load_Markov() error: Couldn't parse line:\n%s\n", line);
      printf("Expecting \"k = %%d\" (make sure to include space around '=')\n");
      exit(7);
  }

  Markov *m = new_Markov(k);
  // A token scanned from `line` cannot be longer than `line` itself, so a
  // buffer of the same size is safe for an unbounded %s.
  char kmer[MAX_LINE_LEN+1];
  int n_kmers = 0;
  while(next_line(fp, line))  {
    double a, c, g, t;
    if(k == 0) {
      // you won't find any kmer at the beginning of the line
      kmer[0] = '\0';
      if(sscanf(line, "%lg%lg%lg%lg", &a, &c, &g, &t) != 4)  {
        printf("load_Markov() error: Couldn't parse line:\n%s\n", line);
        exit(8);
      }
    }  else {
      if(sscanf(line, "%s %lg%lg%lg%lg", kmer, &a, &c, &g, &t) != 5)  {
        printf("load_Markov() error: Couldn't parse line:\n%s\n", line);
        exit(8);
      }
      if (strlen(kmer) != (size_t)k)  {
        printf("load_Markov() error: strlen(%s) = %d != %d\n",
               kmer, (int)strlen(kmer), k);
        exit(8);
      }
    }

    int code = kmer_to_code(kmer, k);
    if(code < 0)  {
      printf("load_Markov() error: kmer '%s' is not made of A, C, G, T\n",
             kmer);
      exit(8);
    }
    if(m->p[code])  {
      printf("kmer '%s' repeated\n", kmer);
      exit(9);
    }
    m->p[code] = new_Pos(a, c, g, t);
    n_kmers++;
  }

  if (n_kmers != pow4[k])  {
    printf("load_Markov() error: #kmers read(%d) != 4^%d = %d\n",
           n_kmers, k, pow4[k]);
    exit(10);
  }
  return m;
}

// Write m to fp. First comes a "k = <order>" header. Then, in code order,
// each line is the kmer, a tab, and its position as written by
// save_Pos(..., what). load_Markov() reads this back when what == PROBS.
void save_Markov(Markov *m, FILE *fp, int what)  {
  int k = m->k;
  fprintf(fp, "k = %d\n", k);
  int code = 0;
  while (code < pow4[k]) {
    fprintf(fp, "%s\t", code_to_kmer(code, k));
    save_Pos(m->p[code], fp, what);
    code++;
  }
}


// Build an order-(k-1) model from a counter of k-mers (k = kc->k).
// For each (k-1)-mer context x, P(next = z | x) is count(xz) normalized
// over z in A, C, G, T.
// A context with no counts gets the uniform distribution, and a warning
// is printed to stdout.
// kc->k must be >= 1; new_Markov() exits on a negative order.
Markov *new_Markov_from_KmerCounter(KmerCounter *kc)  {
  Markov *m = new_Markov(kc->k - 1);
  make_all_kmers(m->k);  // redundant after new_Markov(); kept as a safeguard

  int ctx;
  for (ctx = 0; ctx < pow4[m->k]; ctx++) {
    // Codes put the first character most significant (see kmer_to_code),
    // so the four k-mers extending context ctx by A, C, G, T have codes
    // 4*ctx .. 4*ctx + 3. Changing that encoding would break this.
    double f[4];
    double total = 0;
    int nt;
    for (nt = 0; nt < 4; nt++) {
      f[nt] = kc->count[4 * ctx + nt];
      total += f[nt];
    }

    if (total == 0) {
      printf("new_Markov_from_KmerCounter warning():\n");
      printf("No counts for %s -> (using 0.25,0.25,0.25,0.25)\n",
             code_to_kmer(ctx, m->k));
      m->p[ctx] = new_Pos(0.25,0.25,0.25,0.25);
      continue;
    }

    for (nt = 0; nt < 4; nt++) f[nt] /= total;
    m->p[ctx] = new_Pos(f[0], f[1], f[2], f[3]);
  }
  return m;
}


// Train an order-k model on the sequences read from fp with next_Seq().
// It counts every (k+1)-mer of each sequence with weight 1, then
// normalizes the counts with new_Markov_from_KmerCounter().
// NOTE(review): every sequence is also echoed to stdout via save_Seq().
// This looks like leftover debugging output; confirm before relying on a
// clean stdout.
Markov* new_Markov_from_FA(FILE *fp, int k) {
  KmerCounter *counts = new_KmerCounter(k + 1);

  Seq *seq;
  for (seq = next_Seq(fp); seq != NULL; seq = next_Seq(fp)) {
    save_Seq(seq, stdout);
    add_kmers(counts, seq->seq, 1.0);
    destroy_Seq(seq);
  }

  Markov *model = new_Markov_from_KmerCounter(counts);
  destroy_KmerCounter(counts);
  return model;
}


// Return a new order-k model (k >= m->k) that behaves like m.
// Each context xy, where x is a (k - m->k)-mer and y an m->k-mer, gets a
// copy of m's distribution for y, its last m->k characters.
// k == m->k returns a plain copy. Exits with status 19 if k < m->k.
Markov* new_Markov_increase_order(Markov *m, int k) {
  if (k < m->k) {
    printf("Markov_increase_order error: new k (%d) < old k (%d)\n",
           k, m->k);
    exit(19);
  }
  if (k == m->k) return copy_Markov(m);

  Markov *bigger = new_Markov(k);
  int n_prefix = pow4[k - m->k];  // number of possible x
  int n_suffix = pow4[m->k];      // number of possible y

  // Codes are alphabetical, so code(xy) = 4^(m->k) * code(x) + code(y).
  int y, x;
  for (y = 0; y < n_suffix; y++)
    for (x = 0; x < n_prefix; x++)
      bigger->p[n_suffix * x + y] = copy_Pos(m->p[y]);

  return bigger;
}

// Return a new order-k model (k <= m->k).
// The distribution after context y (a k-mer) is the unweighted average of
// m's distributions after every context xy, where x ranges over all
// (m->k - k)-mers.
// k == m->k returns a plain copy. Exits with status 19 if k > m->k.
Markov* new_Markov_decrease_order(Markov *m, int k) {
  if (k > m->k) {
    printf("Markov_decrease_order error: new k (%d) > old k (%d)\n",
           k, m->k);
    exit(19);
  }
  if (k == m->k) return copy_Markov(m);

  Markov *smaller = new_Markov(k);
  int n_old = pow4[m->k];
  int n_ctx = pow4[k];
  int n_prefix = pow4[m->k - k];

  int y;
  for (y = 0; y < n_ctx; y++) {
    Pos *avg = (Pos*) malloc(sizeof(Pos));
    smaller->p[y] = avg;

    int nt;
    for (nt = 0; nt < 4; nt++) avg->p[nt] = 0;
    // code(xy) = 4^k * code(x) + code(y): the contexts ending in y are
    // y, y + 4^k, y + 2*4^k, ...
    int xy;
    for (xy = y; xy < n_old; xy += n_ctx)
      for (nt = 0; nt < 4; nt++) avg->p[nt] += m->p[xy]->p[nt];
    for (nt = 0; nt < 4; nt++) avg->p[nt] /= n_prefix;

    set_Pos(avg, avg->p[0], avg->p[1], avg->p[2], avg->p[3]);
  }
  return smaller;
}
  


// Score s under m: the sum, over every index j >= k (k = m->k), of the
// rounded log probability logp of s[j] given the preceding k-mer
// s[j-k .. j-1]. The first k characters only serve as context.
// Returns INT_MIN after printing an error if s contains a character other
// than A/C/G/T; the emitted character is now checked as well as the
// context, instead of indexing logp[-1].
// INT_MIN is also returned when a term equals INT_MIN or the sum would drop
// below INT_MIN, instead of overflowing a signed int.
// NOTE(review): Markov_vec() propagates a logp of INT_MIN unchanged, which
// suggests log2_rounded() uses INT_MIN for probability 0. Confirm.
int Markov_score(Markov* m, char *s) {
  int score = 0;
  int n = (int) strlen(s);
  int i;
  for(i = 0; i + m->k < n; i++) {
    int kmer_code = kmer_to_code(s + i, m->k);
    int alpha_code = (kmer_code == -1) ? -1 : alpha_to_code(s[i + m->k]);
    if(kmer_code == -1 || alpha_code == -1) {
      printf("Markov_score error: Couldn't score sequence [%s]\n", s);
      return INT_MIN;
    }

    int term = m->p[kmer_code]->logp[alpha_code];
    // Saturate rather than overflow; INT_MIN - term is safe since term < 0.
    if (term == INT_MIN || (term < 0 && score < INT_MIN - term))
      return INT_MIN;
    score += term;
  }
  return score;
}

// Probability of s under m, conditioned on its first k characters
// (k = m->k). It is the product, over every index j >= k, of
// P(s[j] | s[j-k .. j-1]). A string of at most k characters gives 1.
// Returns -1.0 (never a valid probability) after printing an error if s
// contains a character other than A/C/G/T, instead of indexing the model
// with the -1 error code.
double Markov_prob(Markov* m, char *s) {
  double prob = 1;
  int n = (int) strlen(s);
  int i;
  for(i = 0; i + m->k < n; i++) {
    int kmer_code = kmer_to_code(s + i, m->k);
    int alpha_code = (kmer_code == -1) ? -1 : alpha_to_code(s[i + m->k]);
    if(kmer_code == -1 || alpha_code == -1) {
      printf("Markov_prob error: Couldn't score sequence [%s]\n", s);
      return -1.0;
    }
    prob *= m->p[kmer_code]->p[alpha_code];
  }
  return prob;
}

// Per-position (or cumulative) probabilities or rounded log probabilities
// of s under m.
//
// The first k characters (k = m->k) are the conditioning sequence, and
// scored positions start after them. In 1-based terms,
// seq = x1...xk | x{k+1}, x{k+2}, ..., x{k+n}, where n = strlen(s) - k.
//
// Options:
// *term*  do the calculation starting at the end of the sequence. This
//         changes how entries are indexed even when cum is 0 (compare
//         examples 3 and 4).
// *cum*   return cumulative values (running products of probabilities or
//         running sums of log probabilities) rather than one value per
//         position computed independently.
// *what*  PROBS for probabilities; any other value gives rounded log
//         probabilities (LOGP).
//
// Examples (for i = 0 .. n):
// (1) term = 0, cum = 1, what = PROBS
//     vec[i] = Prob(seq[k+1 .. k+i] | seq[1 .. k])
//     Note: this is 1 by definition for i = 0
// (2) term = 1, cum = 1, what = PROBS
//     vec[i] = Prob(seq[k+i+1 .. k+n] | seq[i+1 .. i+k])
//     Note: this is 1 by definition for i = n
// (3) term = 1, cum = 0, what = LOGP
//     vec[i] = Rounded Log Prob(seq[k+i+1] | seq[i+1 .. i+k])
//     Note: not meaningful for i = n; that slot holds the placeholder 0
//     (1 for PROBS)
// (4) term = 0, cum = 0, what = LOGP
//     vec[i] = Rounded Log Prob(seq[k+i] | seq[i .. k+i-1])
//     Note: not meaningful for i = 0; that slot holds the placeholder 0
//     (1 for PROBS)
// With cum and LOGP, an INT_MIN entry or running sum is propagated as
// INT_MIN rather than added.
//
// Returns a newly malloc'd array of n+2 entries that the caller must free:
// entries 0 .. n as above, and entry n+1 = -1 as an end marker. The void*
// must be cast to double* when what == PROBS and to int* otherwise; this
// simulates function overloading.
//
// Exits with status 16 if s has fewer than k characters. Exits with status
// 17 if a character other than A/C/G/T is met, instead of indexing the
// model with kmer_to_code()/alpha_to_code()'s -1 error code.
void *Markov_vec(Markov* m, char *s, int term, int cum, int what) {
  int len = (int) strlen(s);
  int n = len - m->k;
  if(n < 0) {
    printf("Error: Markov_vec required at least %d characters for conditioning\n", m->k);
    printf("You only provided %d: [%s]\n", len, s);
    exit(16);
  }

  int probs = (what == PROBS);
  double *prob = NULL;
  int *score = NULL;

  // First/last slots start as prob 1 / score 0: the neutral starting value
  // for cumulative results, or a placeholder otherwise. Slot n+1 is the
  // array terminator.
  if(probs)  {
    prob = (double*) malloc((n+2)*sizeof(double));
    prob[0] = 1;
    prob[n] = 1;
    prob[n+1] = -1;
  }  else {
    score = (int*) malloc((n+2)*sizeof(int));
    score[0] = 0;
    score[n] = 0;
    score[n+1] = -1;
  }

  // INITIAL (term == 0) fills i = 1 .. n upwards; TERMINAL fills
  // i = n-1 .. 0 downwards. In both directions, i - step is the
  // neighbouring entry already filled, which a cumulative value builds on.
  int step = term ? -1 : 1;
  int first = term ? n - 1 : 1;
  int stop = term ? -1 : n + 1;  // one past the last index visited
  int i;
  for(i = first; i != stop; i += step) {
    // 0-based start of the conditioning kmer; the emitted character comes
    // right after it.
    int kmer_start = term ? i : i - 1;
    int kmer_code = kmer_to_code(s + kmer_start, m->k);
    int alpha_code = (kmer_code == -1) ? -1 : alpha_to_code(s[kmer_start + m->k]);
    if(kmer_code == -1 || alpha_code == -1) {
      printf("Error: Markov_vec couldn't score sequence [%s]\n", s);
      exit(17);
    }

    int prev = i - step;
    if(probs) {
      prob[i] = m->p[kmer_code]->p[alpha_code];
      if(cum) prob[i] *= prob[prev];
    }  else {
      score[i] = m->p[kmer_code]->logp[alpha_code];
      if(cum) {
        int cum_score = score[prev];
        if((score[i] == INT_MIN) || (cum_score == INT_MIN)) {
          score[i] = INT_MIN;
        }  else {
          score[i] += cum_score;
        }
      }
    }
  }

  return probs ? (void*)prob : (void*)score;
}



// Pre-build the shared kmer tables for every length up to 6.
void init_kmers() {
  const int default_max_k = 6;
  make_all_kmers(default_max_k);
}
// Return 4^k, first extending the shared tables when k > max_k.
// NOTE(review): a negative k is not checked and indexes pow4 out of bounds.
int pow4_fcn(int k) {
  if (k <= max_k) return pow4[k];
  make_all_kmers(k);
  return pow4[k];
}

// Draw the character that follows context `cond` under m. The first m->k
// characters of cond select the distribution (see kmer_to_code).
// The length of cond is not otherwise checked. If those characters are not
// all A/C/G/T (including cond being too short), an error is printed and
// the program exits with status 12, instead of indexing m->p[-1].
//
// Deliberately not `inline`: under C99/C11 a plain `inline` definition is
// not an external definition. A non-inlined call, e.g. the one in
// Markov_generate() at -O0, could then fail to link unless another
// declaration provides one.
char Markov_generate_char(Markov *m, char *cond) {
  int code = kmer_to_code(cond, m->k);
  if (code < 0) {
    printf("Markov_generate_char error: invalid conditioning string [%s]\n",
           cond);
    exit(12);
  }
  return Pos_generate(m->p[code]);
}

// Generate `len` random characters from m, starting from context cond.
// Returns them as a newly malloc'd NUL-terminated string (the caller frees).
// The first m->k characters of cond are the initial context. After each
// step the context slides forward one character to include the character
// just generated. cond itself is not modified.
// Exits with status 13 if cond has fewer than m->k characters, since it
// cannot supply a full context (the sliding window would otherwise run past
// the copied string).
// NOTE(review): len is assumed to be >= 0; it is not checked.
char *Markov_generate(Markov *m, char *cond, int len) {
  int k = m->k;
  if (strlen(cond) < (size_t)k) {
    printf("Markov_generate error: conditioning string [%s] is shorter than k = %d\n",
           cond, k);
    exit(13);
  }

  // Sliding window holding exactly the current k-character context.
  char *window = (char*) malloc((size_t)k + 1);
  memcpy(window, cond, (size_t)k);
  window[k] = '\0';

  char *ret = (char*) malloc((len+1)*sizeof(char));
  ret[len] = '\0';
  int i;
  for(i = 0; i < len; i++) {
    ret[i] = Markov_generate_char(m, window);
    if(k > 0) {
      memmove(window, window + 1, (size_t)(k - 1));
      window[k - 1] = ret[i];
    }
  }
  free(window);
  return ret;
}
    

// Average relative entropy, in bits, of m1 with respect to m2.
// The average runs over all contexts z of the larger order
// K = max(m1->k, m2->k):
//   (1 / 4^K) * sum_z sum_nt p1(nt | z) * log2(p1(nt | z) / p2(nt | z)).
// The lower-order model uses its distribution for the last
// min(m1->k, m2->k) characters of z. Terms with p1 == 0 contribute 0.
double Markov_rel_ent(Markov *m1, Markov *m2) {
  int k_big = (m1->k < m2->k) ? m2->k : m1->k;
  int k_small = (m1->k < m2->k) ? m1->k : m2->k;
  int n_big = pow4[k_big];
  int n_small = pow4[k_small];
  int m1_is_big = (m1->k == k_big);
  int m2_is_big = (m2->k == k_big);

  double total = 0;
  int y, z, nt;
  // For each small context y, z visits the big contexts ending in y:
  // z = y, y + 4^k_small, y + 2*4^k_small, ...
  for (y = 0; y < n_small; y++) {
    for (z = y; z < n_big; z += n_small) {
      Pos *d1 = m1->p[m1_is_big ? z : y];
      Pos *d2 = m2->p[m2_is_big ? z : y];
      for (nt = 0; nt < 4; nt++) {
        double p1 = d1->p[nt];
        double p2 = d2->p[nt];
        if (p1 > 0) total += p1 * log2(p1 / p2);
      }
    }
  }

  return total / n_big;
}

// Write m to fp as an R expression of the form
//   list(k=<k>,p=matrix(c(...), byrow=T, ncol=4,
//        dimnames=list(c('<kmer>',...),c('A','C','G','T'))))
// The matrix has one row per kmer, in code order, holding P(A), P(C),
// P(G), P(T).
void Markov_to_R(Markov* m, FILE* fp) {
  int n = pow4[m->k];
  int code, nt;

  fprintf(fp, "list(k=%d,p=matrix(c(", m->k);
  int first_value = 1;
  for (code = 0; code < n; code++) {
    for (nt = 0; nt < 4; nt++) {
      if (!first_value) fprintf(fp, ",");
      first_value = 0;
      fprintf(fp,"%g",m->p[code]->p[nt]);
    }
  }

  fprintf(fp,"),\n     byrow=T, ncol=4, dimnames=list(c(");
  for (code = 0; code < n; code++) {
    if (code > 0) fprintf(fp,",");
    fprintf(fp,"'%s'", code_to_kmer(code,m->k));
  }
  fprintf(fp,"),c('A','C','G','T'))))\n");
}
