#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include "Likelihood_score.h"
#include "Genome_profile_aux.h"
#include "optimization.h"
#define VERBOSE 0

/* Forward declarations of the helpers that are private to this file.
   All four are defined further down; the definitions omit the `static`
   keyword but keep internal linkage because of these prior declarations. */
static void Normalize_reads(float *y, int nn, int BOXSIZE);
static void Sum_boxes(float *y, long *x, long n, long *x1, long n1);
static float Lik_score_cr_opt_r(float *E, float *AIC, long *N_peaks,
				float *r_in, float *r_out,
				long nnc, float PARAPEN,
				float *ye, float mre, float *yc, float mrc);
static float Lik_score(float *enhancement, 
		       long *N_peaks, float *r_out,
		       float *yye, float mre,
		       float *yyc, float mrc,
		       long nnc, float r,
		       float PARAPEN,
		       int OUT);


/* Fill every control profile with a constant.
   yc_box : Nchr arrays, one per chromosome
   nn     : nn[c] is the number of boxes of chromosome c
   value  : constant written into each of the nn[c] boxes */
void  Set_control(float **yc_box, int Nchr, long *nn, float value)
{
  for(int chr=0; chr<Nchr; chr++){
    float *profile=yc_box[chr];
    long box=0;
    while(box<nn[chr]){
      profile[box]=value;
      box++;
    }
  }
}

float Make_box(float **y_box, float **y,
	       long **x, long *nn, int N, float DAMP, int *WIN, float l_EPS)
{
  /* Weighted average of the number of reads in neighboring boxes
   */

  float BOXSIZE=x[0][1]-x[0][0];
  float DampFact=exp(-BOXSIZE/DAMP);
  *WIN=l_EPS*DAMP/BOXSIZE;

  int k, ic; long i, j;

  float *Damp=malloc((*WIN+1)*sizeof(float));
  Damp[0]=1.0;
  float norm=Damp[0];
  for(k=1; k<= *WIN; k++){
    Damp[k]=Damp[k-1]*DampFact; norm+=2*Damp[k];
  }
  for(ic=0; ic<N; ic++){
    long *xc=x[ic], nnc=nn[ic];
    float *y_new=y_box[ic], *y_old=y[ic];
    for(i=0; i<nnc; i++){
      float ww=1, ysum=(*y_old);
      float *w=Damp+1, *y_ptr=y_old+1; j=i+1;
      for(k=1; k<=*WIN; k++){
	if(j>=nnc)break;
	ysum+=(*y_ptr)*(*w); ww+=(*w); w++; j++; y_ptr++;
      }
      w=Damp+1; y_ptr=y_old-1; j=i-1; 
      for(k=1; k<=*WIN; k++){
	if(j<0)break;
	ysum+=(*y_ptr)*(*w); ww+=(*w); w++; j--; y_ptr--;
      }
      *y_new=(ysum/ww);
      y_new++; y_old++;
    }
  }
  free(Damp);
  return(norm);
}

/* Genome-wide likelihood score with the parameter r optimised separately
 * for every chromosome.
 *
 * enhancement : per-chromosome output arrays filled by the scoring routine
 * AIC         : output; sum of the per-chromosome values returned through
 *               Lik_score_cr_opt_r, divided by the total number of peaks
 *               (left as the plain sum, i.e. 0, when no peak was found)
 * rr          : output; rr[ic] is the optimal r of chromosome ic
 * ye, yc      : experiment and control profiles, nn[ic] boxes each
 * mre, mrc    : mean reads of experiment and control
 * Ncr         : number of chromosomes
 * PARAPEN     : penalty subtracted per selected box
 *
 * Returns the summed score divided by the summed weight nn[ic]*BOXSIZE,
 * or 0 when the total weight is not positive.
 * NOTE(review): BOXSIZE is not declared in this function; it is presumably
 * a global or macro from one of the project headers - confirm.
 */
float Likelihood_score_opt_r(float **enhancement, float *AIC, float *rr,
			     float **ye, float mre, float **yc,
			     float mrc, long *nn, int Ncr, float PARAPEN)
{
  double sum=0, w=0; long N_peaks=0;
  int ic;

  /* The accumulator used to rely on the caller having zeroed *AIC */
  *AIC=0;
  if(VERBOSE)printf("# chr score r  m\n");
  for(ic=0; ic<Ncr; ic++){
    long N_p=0; float r_in=0, r_out=0, AIC_chr=0;
    float scr=Lik_score_cr_opt_r(enhancement[ic], &AIC_chr, &N_p,
				 &r_in, &r_out, nn[ic], PARAPEN,
				 ye[ic], mre, yc[ic], mrc);
    float ww=nn[ic]*BOXSIZE;
    if(VERBOSE)printf("%d %.3f  %.2f %.2f   %.1f %.1f\n", 
		      ic, scr/ww, r_in, r_out, mre, mrc);
    sum+=scr; w+=ww; rr[ic]=r_in; (*AIC)+=AIC_chr; N_peaks+=N_p;
  }
  /* Avoid 0/0 when no box was selected as a peak */
  if(N_peaks>0)(*AIC)/=N_peaks;
  if(!(w>0))return(0);
  return(sum/w);
}

/* Search, for one chromosome, the value of r that maximises Lik_score.
 *
 * The score is first evaluated (with OUT=0) at r=0.4, 0.5 and 0.6. Then,
 * for at most 100 iterations, a new trial r is obtained from
 * Find_max_quad on the current three points and scored:
 *   - a trial that scores below the best value so far ends the search and
 *     is discarded;
 *   - a trial whose relative gain over the best value is below 0.01 is
 *     accepted and ends the search;
 *   - any other trial is accepted and fed to Rearrange_points.
 * NOTE(review): Find_max_quad presumably returns the maximum of the
 * parabola through the three points restricted to [0,1], and
 * Rearrange_points presumably updates the bracket - confirm in
 * optimization.h.
 *
 * *r_in receives the selected r. Lik_score is finally run once more with
 * OUT=1 at that r, which fills E, *N_peaks and *r_out; its return value is
 * stored in *AIC. The function returns the best OUT=0 score.
 */
float Lik_score_cr_opt_r(float *E, float *AIC, long *N_peaks,
			 float *r_in, float *r_out, long nnc, float PARAPEN,
			 float *ye, float mre, float *yc, float mrc)
{
  /* Initial three-point bracket and its scores */
  float ra=0.4, rb=0.5, rc=0.6;
  float sa=Lik_score(E, N_peaks, r_out, ye, mre, yc, mrc, nnc, ra, PARAPEN, 0);
  float sb=Lik_score(E, N_peaks, r_out, ye, mre, yc, mrc, nnc, rb, PARAPEN, 0);
  float sc=Lik_score(E, N_peaks, r_out, ye, mre, yc, mrc, nnc, rc, PARAPEN, 0);

  /* Best point seen so far */
  float best_s=sa, best_r=ra;
  if(sb>best_s){best_s=sb; best_r=rb;}
  if(sc>best_s){best_s=sc; best_r=rc;}

  int iter=0;
  while(iter<100){
    float r_try=Find_max_quad(ra, rb, rc, sa, sb, sc, 0, 1);
    float s_try=Lik_score(E, N_peaks, r_out, ye, mre, yc, mrc, nnc,
			  r_try, PARAPEN, 0);
    if(VERBOSE)printf("%.3f %.2f\n", r_try, s_try/nnc);
    if(s_try<best_s)break;

    /* Relative improvement measured against the previous best */
    int small_gain=((s_try-best_s)/fabs(best_s)<0.01);
    best_s=s_try; best_r=r_try;
    if(small_gain)break;

    Rearrange_points(&ra, &rb, &rc, &sa, &sb, &sc, r_try, s_try);
    iter++;
  }

  *r_in=best_r;
  *AIC=Lik_score(E, N_peaks, r_out, ye, mre, yc, mrc, nnc, best_r, PARAPEN, 1);
  return(best_s);
}

float Lik_score(float *enhancement, 
		long *N_peaks, float *r_out,
		float *yye, float mre,
		float *yyc, float mrc,
		long nnc, float r,
		float PARAPEN,
		int OUT)
{
  double score=0, reads_exp=0;
  float rmre=mre*r, ycontr;
  float log_c=log(mrc);
  float log_e=log(rmre);
  float log_all=log(mrc+rmre);
  float coeff=rmre/mrc;
  float ce=mrc/((1-r)*mre);
  long i, np=0;
  float *ye=yye, *yc=yyc, *E=enhancement, delta; 
  for(i=0; i<nnc; i++){
    ycontr=((*yc)+mrc)*0.5;
    //if(*yc >mrc){ycontr=*yc;}else{ycontr=mrc;}
    float ya=(*ye)+ycontr;
    float lik0=ycontr*log_c+(*ye)*log_e-ya*log_all;
    float ynull=ycontr*coeff;
    if(OUT==0){score+=lik0;}
    if((*ye)<ynull){if(OUT)*E=0; goto next;}
    delta=0;
    if(ya)delta+=ya*(log_all-log(ya));
    if(ycontr)delta+=ycontr*(log(ycontr)-log_c);
    if(*ye)delta+=(*ye)*(log(*ye)-log_e);
    if(OUT)*E=delta;
    if(delta>PARAPEN){
      score+=(delta-PARAPEN);
      if(OUT){
	float dy=(*ye)-ynull;
	reads_exp+=dy;
	//*E=ce*dy/ycontr;
	np++;
      }
    }
  next:
    ye++; yc++; if(OUT)E++;
  }
  if(OUT){
    *r_out=1.-reads_exp/(mre*nnc);
    *N_peaks=np;
  }
  /* If OUT=0, score is the total likelihood - Parameter penalties
     otherwise score is summed only when the null model is rejected */
  return(score);
}

/* Read the experiment profiles (and optionally the control profiles) of
 * N chromosomes and write a short report to the file outp.
 *
 * CONTR    : if non-zero, the control files file_c[k] are read as well
 * x_coord  : output, box coordinates of each chromosome
 * y_contr  : output, control profiles; the array of N pointers is always
 *            allocated but its entries are only set when CONTR is non-zero
 * y_exper  : output, experiment profiles
 * nn       : output, number of boxes of each chromosome
 * file_c/e : names of the control / experiment files
 * outp     : name of the report file
 *
 * When control and experiment have a different number of lines (within
 * 10%, otherwise the program exits), the longer profile is summed onto
 * the boxes of the shorter one with Sum_boxes. After reading, the
 * coordinates of the two profiles must be identical box by box.
 *
 * All arrays are allocated here and become property of the caller.
 * Every fatal condition (allocation failure, report file that cannot be
 * opened, inconsistent inputs) ends the program with exit(8).
 */
void Read_contr_exp(int CONTR, long ***x_coord,
		    float ***y_contr, float ***y_exper,
		    long **nn, char **file_c, char **file_e, int N, char *outp)
{
  int k; long i, n2;
  long  **x2=NULL;
  if(CONTR)x2=malloc(N*sizeof(long *));
  *x_coord=malloc(N*sizeof(long *));
  *y_contr=malloc(N*sizeof(float *));
  *y_exper=malloc(N*sizeof(float *));
  *nn=malloc(N*sizeof(long));
  if(N>0 &&
     ((CONTR && x2==NULL) || *x_coord==NULL || *y_contr==NULL ||
      *y_exper==NULL || *nn==NULL)){
    printf("ERROR, cannot allocate tables for %d chromosomes\n", N);
    exit(8);
  }

  FILE *file_out=fopen(outp, "w"); 
  if(file_out==NULL){
    printf("ERROR, cannot open %s for writing\n", outp);
    exit(8);
  }

  for(k=0; k<N; k++){
    printf("Reading %s\n", file_e[k]);
    (*nn)[k]= Read_file(&(*x_coord)[k], &(*y_exper)[k], file_e[k]);
    fprintf(file_out, "Chr %d %ld lines read\n", k+1, (*nn)[k]);
    printf("Chr %d %ld lines read\n", k+1, (*nn)[k]);
    if(CONTR){
      n2=       Read_file(&x2[k], &(*y_contr)[k], file_c[k]);
      if((*nn)[k]!=n2){
	printf("WARNING, different number of lines in files\n");
	/* Both counts are long: %ld, not %d */
	printf("%s (%ld) and\n%s (%ld)\n",
	       file_e[k], (*nn)[k], file_c[k], n2);
	if((n2<0.9*(*nn)[k])||(n2>1.1*(*nn)[k]))exit(8);
	if(n2 > (*nn)[k]){
	  /* Control is finer: sum it onto the experiment boxes */
	  Sum_boxes((*y_contr)[k], x2[k], n2, (*x_coord)[k], (*nn)[k]);
	}else if(n2 < (*nn)[k]){
	  /* Experiment is finer: sum it onto the control boxes */
	  Sum_boxes((*y_exper)[k], (*x_coord)[k], (*nn)[k], x2[k], n2);
	  (*nn)[k]=n2;
	}
      }
      float r=Corr_coeff((*y_contr)[k], (*y_exper)[k], (*nn)[k]);
      fprintf(file_out, "Corr(exper,control)= %.3f\n", r);
    }
  }
  fclose(file_out);

  // Check consistence
  if(CONTR){
    for(k=0; k<N; k++){
      long *xe=(*x_coord)[k], *xc=x2[k];
      for(i=0; i<(*nn)[k]; i++){
	if(xc[i]!=xe[i]){
	  printf("ERROR, different coordinates: %ld %ld\n", xc[i], xe[i]);
	  printf("%s\n%s\n", file_c[k], file_e[k]);
	  printf("Statistics written in %s\n", outp); exit(8);
	}
      }
      free(x2[k]);
    }
    free(x2);
  }
}

/* Write one score file per chromosome.
 *
 * The output name is built by Output_name from the experiment file name
 * file_e[k] and the suffix "_W<WIN>.scr". Each file starts with a
 * "#WINDOW= <WIN>" header, followed by one "coordinate score" line per
 * box.
 *
 * x_contr : box coordinates of each chromosome
 * Zscore  : scores of each chromosome
 * nn      : number of boxes of each chromosome
 * minmin  : not used by this function
 *
 * Ends the program with exit(8) if an output file cannot be opened.
 */
void Print_score(long **x_contr, float **Zscore, long *nn,
		 char **file_e, int N, int WIN, float minmin)
{
  char nameout[400]; int k; long i;
  /* The suffix only depends on WIN: build it once, with a bounded write */
  char ext[20]; snprintf(ext, sizeof ext, "_W%d.scr", WIN);

  for(k=0; k<N; k++){
    Output_name(nameout, file_e[k], ext);
    long *xc=x_contr[k];
    FILE *file_out=fopen(nameout, "w");
    if(file_out==NULL){
      printf("ERROR, cannot open %s for writing\n", nameout);
      exit(8);
    }
    fprintf(file_out, "#WINDOW= %d\n", WIN);
    printf("Chr %d output: %s\n", k+1, nameout);
    for(i=0; i<nn[k]; i++){
      fprintf(file_out, "%ld %.3f\n", xc[i], Zscore[k][i]);
    }
    fclose(file_out);
  }
}

float Mean_counts(float *sd, float **y, long *nn, int N)
{
  int ic, i; double sum=0, sum2=0; long nw=0;
  for(ic=0; ic<N; ic++){
    double sumc=0, sumc2=0; float *yy=y[ic];
    for(i=0; i<nn[ic]; i++){sumc+=*yy; sumc2+=(*yy)*(*yy); yy++;}
    nw+=nn[ic]; sum+=sumc; sum2+=sumc2;
  }
  sum/=nw; sum2=(sum2-nw*sum*sum)/(nw-1);
  *sd=sqrt(sum2);
  return(sum);
}

/* Multiply every value of the N profiles by scale, in place.
   y  : N arrays, nn[ic] values each
   nn : number of boxes of each profile */
void Rescale_counts(float **y, long *nn, int N, float scale){
  int ic; long i;
  for(ic=0; ic<N; ic++){
    float *yy=y[ic];
    /* long index, like the counts in nn and like Set_control */
    for(i=0; i<nn[ic]; i++){yy[i] *= scale;}
  }
}

/* Rescale a profile in place by its smallest positive value.
 *
 * With min_y the smallest strictly positive entry and mean the average of
 * all nn entries, every entry is multiplied by
 *     1/(min_y*BOXSIZE)   if mean/min_y > BOXSIZE
 *     1/min_y             otherwise.
 * The profile is left untouched when it holds no positive value (this
 * includes nn<=0).
 */
void Normalize_reads(float *y, int nn, int BOXSIZE){
  float min_y=0, coeff;
  double sum=0; int i, found=0;

  /* Track the real minimum instead of starting from a fixed 100000, which
     was silently used as the minimum for profiles above that value */
  for(i=0; i<nn; i++){
    sum+=y[i];
    if((y[i]>0)&&(!found || y[i]<min_y)){min_y=y[i]; found=1;}
  }
  if(!found)return;

  sum/=nn;
  if(sum/min_y > BOXSIZE){coeff= 1./(min_y*BOXSIZE);}
  else{coeff=1./min_y;}

  for(i=0; i<nn; i++){
    y[i] = (y[i]*coeff);
  }
}

/* Sum a fine profile onto a coarser grid, in place.
 *
 * y, x   : fine profile and its coordinates (n entries each)
 * x1     : coordinates of the coarse grid (n1 entries)
 *
 * Coarse box i receives the next fine value plus all the following fine
 * values whose coordinate does not exceed x1[i]. The sums are written at
 * the beginning of y and the first n1 entries of x are then overwritten
 * with x1 (never more than n entries). The ratio x1[0]/x[0] is only used
 * for the informational message.
 */
void Sum_boxes(float *y, long *x, long n, long *x1, long n1)
{
  if(n<=0 || n1<=0)return;

  /* x[0]==0 would be an integer division by zero */
  long m=(x[0]!=0)? x1[0]/x[0]: 0;
  printf("Summing %ld boxes (from step %ld to step %ld)\n",m,x[0], x1[0]);
  if((m*x[0])!=x1[0]){
    printf("WARNING, the step ratio %f is not integer!\n", (float)x1[0]/x[0]);
  }

  long i, j=0;   /* j counts the fine boxes consumed so far */
  float *y1=y;   /* the coarse sums overwrite the start of y */
  for(i=0; i<n1 && j<n; i++){
    y1[i]=y[j]; j++;
    /* j<n is tested before x[j] is read, so no read past the end */
    while(j<n && x[j]<=x1[i]){
      y1[i]+=y[j]; j++;
    }
  }

  /* The profile now lives on the coarse grid */
  for(i=0; i<n1 && i<n; i++){
    x[i]=x1[i];
  }
}

