// prfreq.c    last update: Nov 7, 2007
//             code owner: Adam Boyko
//
// samplesize - number of individuals in derived pop sample
// Ne - effective population size of curr population (i.e. Nc)
// THETA - mutation param (4*Nc*mu)
// TAU - time since pop expansion (units: 2*Nc*gen time)
// OMEGA - magnitude of expansion: OMEGA = Ne (ancest) / Ne (curr)
// u[0..2*Ne] - distrib of allele freq ([0..1] scaled by 2*Nc)
// NOTE: GAMMA/PHI in terms of Ncurr (not Ne)
//
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <sys/time.h>
#include <time.h>
#include <string.h>
#include "nrutil.h"
#include "nrutil.c"
#include "myrandom.c"
#include <gsl/gsl_cdf.h>


// global parameters (compile-time switches and array bounds)
// DEBUG: enables the extra "file is open" / "distrib = ..." printouts in main
//        and the iright/iwrong counters declared further down.
#define DEBUG
// max_pieces: capacity of int_type[], min_int[] and max_int[], i.e. the maximum
//             number of 'I' integration-piece lines read from the settings file.
#define max_pieces 14
// max_dims: capacity of minp[], maxp[], step[] and param[], i.e. the maximum
//           number of selection-distribution parameters searched over.
#define max_dims 5
// CALC_THETA (off): when defined, main prints a theta estimate and exits.
//#define CALC_THETA
// EXTRAPOLATE0 (off): when defined, LL adds a fixed neutral weight w to hesfs[].
//#define EXTRAPOLATE0

// ADD_NEUT_MASS: when defined, LL computes the neutral weight ww for
//                distrib 6 and 9, and main reads the neutral expected sfs file.
#define ADD_NEUT_MASS
// Lookup table for the neutral point mass, indexed na[i1][i2] with
// i1 in 0..40 and i2 in 0..150; filled lazily by neut() on its first call.
double na[41][151];
// Non-zero until neut() has loaded the table; cleared on the first call.
int __NEUT = 1;


// this is a hack (used with EXTRAPOLATE0); neutfile contains (150+1)x(40+1) values for neutral point mass; must have 150x40 steps in selection grid
//
// neut: returns the tabulated neutral point mass for selection parameters (p1, p2).
//
// On the first call the table is read from a hard-coded file: 151 lines of
// 41 whitespace-separated doubles each; line r, column c is stored in na[c][r].
// The whole table is echoed to stdout after loading.
//
// The program exits with status 1 if the file cannot be opened, if it has
// fewer than 151 lines, if a line has fewer than 41 parseable values, or if
// (p1, p2) maps to an index outside the table.
double neut(double p1, double p2) {

  if (__NEUT) {
    __NEUT = 0;
    FILE *neutfile = fopen("/Users/arb83/prfreq/nov/afr_infer_neut_lognorm.txt", "r");
    if (neutfile == NULL) {
      printf("Error opening neutral mass file: /Users/arb83/prfreq/nov/afr_infer_neut_lognorm.txt\n");
      exit(1);
    }
    char buf[1027];
    int row, col;
    for(row=0;row<=150;row++) {
      if (fgets(buf, sizeof buf, neutfile) == NULL) {
        printf("Error: neutral mass file ended after %d of 151 rows\n", row);
        exit(1);
      }
      // Parse the 41 columns one at a time so a short line is detected
      // instead of silently leaving table entries at zero.
      char *cursor = buf;
      for(col=0;col<=40;col++) {
        char *next;
        na[col][row] = strtod(cursor, &next);
        if (next == cursor) {
          printf("Error: neutral mass file row %d has only %d of 41 values\n", row, col);
          exit(1);
        }
        cursor = next;
      }
    }
    for(row=0;row<=150;row++) {
      for(col=0;col<=40;col++) printf("%lf ",na[col][row]);
      printf("\n");
    }
    fclose(neutfile);
  }

  // Map the parameter values onto grid indices; the offsets and scales are
  // specific to the "afr lognorm" grid the table was generated for.
  int i1 = (p1+0.00001)*20.0 - 82;   //afr lognorm
  int i2 = (p2+0.00001)*50.0 - 225;   //afr lognorm
  if ((i1 < 0)||(i1 > 40)||(i2 < 0)||(i2 > 150)) {
    printf("Error in neut: (p1=%f, p2=%f) maps to index [%d][%d], outside the 41x151 table\n", p1, p2, i1, i2);
    exit(1);
  }
  return na[i1][i2];
}



// returns the prob dens fxn of N(M,SD) at x
double pdf_norm(const double x, const double M, const double SD) {
  if (SD < -1.0e-7) {
    printf("pdf_norm error: SD < 0! (SD = %f)\n", SD);
    exit(1);
  } else if (SD < 1.0e-7) return 1.0;
  return (exp(-0.5*(x-M)*(x-M)/(SD*SD))/(SD*sqrt(2.0*M_PI)));
}

// returns the prob dens fxn of LOGNORM(M,SD) at x
double pdf_lognorm(const double x, const double M, const double SD, const double loc) {
  if (SD < -1.0e-7) {
    printf("pdf_lognorm error: SD < 0! (SD = %f)\n", SD);
    exit(1);
  } else if (SD < 1.0e-7) return 1.0;
  if (x-loc < -1.0e-7) {
    printf("pdf_lognorm error: x < 0! (x = %f)\n", x);
    exit(1);
  } else if (x-loc < 1.0e-8) return 0.0;
  // return exp(pdf_norm(x, M, SD));
  return (exp(-0.5*pow((log(x-loc)-M)/SD,2))/((x-loc)*SD*sqrt(2.0*M_PI)));
}

// pdf_gamma: probability density at x of a gamma distribution with shape h
// and scale _theta. Prints an error and exits with status 1 when x < 0 or
// h < 0 (x is checked first). gammln() supplies log(Gamma(h)).
double pdf_gamma(const double x, const double h, const double _theta) {
  if (x<0.0) {
    printf("Error in pdf_gamma: x<0: x = %lf\n", x);
    exit(1);
  }
  if (h<0.0) {
    printf("Error in pdf_gamma: h<0: h = %lf\n", h);
    exit(1);
  }

  const double numer = pow(x,h-1.0)*exp(-x/_theta);
  const double denom = exp(gammln(h))*pow(_theta,h);
  return numer/denom;
}


// ---- program-wide state (all file-scope, shared by every routine) ----

// Output grid file name; copied from argv in main when output_grids != 0.
char outgridfilename[128];
// fixed, poisson, samplesize: read from the settings file in main.
// num_dims: number of distribution parameters, chosen from distrib in main.
// num_pieces / int_type[]: integration pieces parsed from the 'I' settings lines.
int fixed, poisson, samplesize, num_dims, num_pieces, int_type[max_pieces];
// Ne, THETA, tdiv: read from the settings file. TAU*/OMEGA*: the demographic
// grid point currently being evaluated (set by the loops in FindMaxLikelihood).
double Ne, THETA, tdiv, TAU, TAU_B, OMEGA, OMEGA_B, GAMMA;
// Grid definitions for the demographic parameters: min/max bounds, a step
// setting, and the best value found by FindMaxLikelihood.
double minOMEGA, maxOMEGA, stepOMEGA, minTAU, maxTAU, stepTAU;
double minOMEGA_B, maxOMEGA_B, stepOMEGA_B, minTAU_B, maxTAU_B, stepTAU_B;
double bestOMEGA, bestTAU, bestOMEGA_B, bestTAU_B;
double minOMEGA_B2, maxOMEGA_B2, bestOMEGA_B2, stepOMEGA_B2, OMEGA_B2;
double minTAU_B2, maxTAU_B2, bestTAU_B2, stepTAU_B2, TAU_B2;
// arsize is set to 2*Ne in main. iter is set by read_ingrid_file (or 0) and
// reset to 0 inside FindMaxLikelihood when a demographic parameter is gridded.
// currj is the frequency-class index that LL uses as its loop variable.
int arsize, iter, currj=1;
// Run configuration read from the settings file in main.
int distrib, read_in_snps, run_type, input_grids, output_grids, num_iter;
// Work vectors allocated in main with index range 0..2*Ne.
double *u, *rhs, *lft, *mid, *rt, *D, *Q;  //*mtmp;
// Discretisation constants derived from Ne in main (delx, delt, al, be, gam, gam2).
double delx, delt, al, be, gam, gam2, max_eps;
// Per-dimension search grid (minp/maxp/step) and the current parameter vector.
double minp[max_dims], maxp[max_dims], step[max_dims], param[max_dims];
// Integration limits for each piece, parallel to int_type[].
double min_int[max_pieces], max_int[max_pieces];
double tnew, tnew2, tot_snps;   // tot_snps: total observed SNP count (read_snp_file)
int newfreq, newfreq2;
// esfs: grid of spectra (rows 0..GMAX, columns 0..samplesize; allocated in main).
// sfs: observed spectrum, or the expected one when LL is called with logl == 0.
// hesfs: expected spectrum built by LL. param_est: best parameters per iteration.
double **esfs, *sfs, *hesfs, **param_est;
FILE *outfile, *outgridfile;

// full_exp_sfs: hesfs plus the weighted neutral spectrum (filled by LL).
double *full_exp_sfs;
// neut_exp_sfs: neutral expected spectrum loaded by read_neut_sfs_file.
double *neut_exp_sfs;
// definitions for Romberg integration
//#define EPS 5.0e-5
//#define EPS 1.0e-5 //originally, boyko had 1e-9
//#define EPS2 1.0e-5 //originally, boyko had 1e-9
//#define JMAX 15
//#define JMAXP (JMAX+1)
//#define GMAX 14348907       // GMAX = 3^JMAX for midpnt, midinf, etc; originally was 65536


// Settings for the Romberg integration routines (tolerances, maximum number
// of refinement steps, and the size of the cached grid).
#define EPS 5.0e-5 //originally, boyko had 1e-9, later set to 1.0e-5, now set to 5.0e-5 by CHRISTIAN HUBER
#define EPS2 5.0e-5 //originally, boyko had 1e-9, later set to 1.0e-5, now set to 5.0e-5 by CHRISTIAN HUBER
#define JMAX 12 //originally, boyko had 15, but took extremely long for gamma distr with large alpha and beta CHRISTIAN HUBER
#define JMAXP (JMAX+1)
// GMAX is also the highest row index of esfs (see the dmatrix call in main).
#define GMAX 65536 // try also smaller GMAX this time CHRISTIAN HUBER

// NOTE(review): K is not used in the part of the file reviewed here; presumably
// the number of points used for extrapolation in qromo1 -- confirm.
#define K 5
// Set to 1 by outgroup_fixed() while it calls LL with samplesize forced to 1.
int OUTGROUP = 0;

#ifdef DEBUG
// Debug-only counters; not referenced in the part of the file reviewed here.
int iright=0, iwrong=0;
#endif //DEBUG

// Forward declarations. The definitions are not in the part of the file
// reviewed here; LL calls expectedsfs1 and qromo1, passing midpnt, midinf,
// midexp or midexpN as the quadrature rule.
double transient(double x, double *dist);
void tridag(void);
double PRFHazardSC(int count, double *dist);
double expectedsfs1(double g);
double qromo1(double a, double b, double (*choose)(double, double, int));
double midpnt0(double a, double b, int n, int count, double *dist);
double midpnt(double a, double b, int n);
double midinf(double aa, double bb, int n);
double midexp(double aa, double bb, int n);
double midexpN(double aa, double bb, int n);
double mrate(double x, int count, double *dist);
void polint(double xa[], double ya[], int n, double x, double *y, double *dy);
double weight(double _gamma);
// Index of the column that comp() compares; set by the caller before sorting.
int currParam=0;
// comp: qsort-style comparator for rows of doubles.
// ptrA and ptrB each point to the first element of an array of doubles; the
// rows are ordered by their element at index currParam.
// Returns -1, 0 or 1 when A's element is less than, equal to, or greater
// than B's. The comparison is written as a difference of two boolean tests
// so it cannot overflow or truncate the way a subtraction of doubles would.
//
// The generic (const void *) signature replaces a typed one to eliminate a
// compiler warning (original author noted a ~2% slowdown).
int comp(const void *ptrA, const void *ptrB) {
   const double *da = (const double *) ptrA;
   const double *db = (const double *) ptrB;
   return (da[currParam] > db[currParam]) - (da[currParam] < db[currParam]);
}

// hesfs contains an expected frequency spectrum //kel comment
//
// LL: builds the expected site-frequency spectrum for the current values of
// param[] (and the current demographic globals), then either scores the
// observed spectrum against it or stores it as the "observed" spectrum.
//
//   logl != 0 : returns the NEGATIVE log-likelihood of sfs[] given the
//               expected spectrum (multinomial when poisson == 0, Poisson
//               otherwise; folded when fixed == 3).
//   logl == 0 : copies the expected spectrum into sfs[1..samplesize] and
//               returns 0.0.
//
// Side effects: overwrites hesfs[0..samplesize] and
// full_exp_sfs[1..samplesize], leaves the global currj modified, and prints
// the total weight plus a summary line to stdout on every call.
double LL(int logl) {
  // Highest frequency class that enters the likelihood: the fixed class
  // (j == samplesize) is included only when fixed is 1 or 2.
  int maxj = ((fixed==1)||(fixed==2)) ? samplesize : samplesize-1;
  int i, j, counter;
  // Step 1: fill hesfs[] for the chosen distribution. The loops run on the
  // GLOBAL currj; expectedsfs1/qromo1 take no index argument, so presumably
  // they read currj to know which entry is being computed -- confirm.
  switch ( distrib ) {
  case 0 : for(currj=0;currj<=samplesize;currj++) hesfs[currj]=expectedsfs1(0.0);
           break;
  case 1 : for(currj=0;currj<=samplesize;currj++) hesfs[currj]=expectedsfs1(param[0]);
           break;
  case 2 : for(currj=0;currj<=samplesize;currj++) hesfs[currj]=expectedsfs1(0.0);
           for(currj=0;currj<=samplesize;currj++) hesfs[currj]+=expectedsfs1(param[1]);
           break;
  case 3 : for(currj=0;currj<=samplesize;currj++) hesfs[currj]=expectedsfs1(0.0);
           for(currj=0;currj<=samplesize;currj++) hesfs[currj]+=expectedsfs1(param[2]);
           for(currj=0;currj<=samplesize;currj++) hesfs[currj]+=expectedsfs1(param[3]);
           break;
  case 16: for(currj=0;currj<=samplesize;currj++) hesfs[currj]=expectedsfs1(0.0);
           break;
  case 17: for(currj=0;currj<=samplesize;currj++) hesfs[currj]=expectedsfs1(0.0);
           for(currj=0;currj<=samplesize;currj++) hesfs[currj]+=expectedsfs1(param[1]);
           break;

  // distrib 13: if the three point-mass weights sum to more than 1 the
  // spectrum is filled with the placeholder value 1.0 instead of being computed.
  case 13: if (param[0]+param[1]+param[2] > 1.0)
             for(currj=0;currj<=samplesize;currj++) hesfs[currj]=1.0;
           else {
           for(currj=0;currj<=samplesize;currj++) hesfs[currj]=expectedsfs1(-500.0);
           for(currj=0;currj<=samplesize;currj++) hesfs[currj]+=expectedsfs1(-50.0);
           for(currj=0;currj<=samplesize;currj++) hesfs[currj]+=expectedsfs1(-5.0);
           if (param[3] == 0.0) for(currj=0;currj<=samplesize;currj++) hesfs[currj]+=expectedsfs1(0.0);
           else {  //printf("weight before qromo(%f,%f,%f,%f)=%f\n", param[0], param[1], param[2], param[3], hesfs[0]);
                   for(currj=0;currj<=samplesize;currj++)
                       hesfs[currj]+=qromo1(0.0,param[3],midpnt);
                   //printf("weight after qromo(%f,%f,%f,%f)=%f\n", param[0], param[1], param[2], param[3], hesfs[0]);
                }
           } //else
           break;

  // All other distributions: a normal (distrib 4) with near-zero stdev is
  // treated as a point mass at param[0]; otherwise the spectrum is the sum
  // of one numerical integral per configured integration piece.
  default : if ((distrib == 4)&&(param[1] < 1.0e-7))
              for(currj=0;currj<=samplesize;currj++) hesfs[currj]=expectedsfs1(param[0]);
           else {
           for(currj=0;currj<=samplesize;currj++) hesfs[currj]=0.0;
           for(i=0; i<num_pieces; i++) {
             switch ( int_type[i] ) {
             case 1: for(currj=0;currj<=samplesize;currj++)
                       hesfs[currj]+=qromo1(min_int[i],max_int[i],midpnt);
                     break;
             case 2: for(currj=0;currj<=samplesize;currj++)
                       hesfs[currj]+=qromo1(min_int[i],max_int[i],midinf);
                     break;
             case 3: for(currj=0;currj<=samplesize;currj++)
                       hesfs[currj]+=qromo1(min_int[i],max_int[i],midexp);
                     break;
             case 4: for(currj=0;currj<=samplesize;currj++)
                       hesfs[currj]+=qromo1(min_int[i],max_int[i],midexpN);
                     break;
             default: printf("Warning: ignoring int_type=%d (param %d)\n", int_type[i], i); break;
             } //switch
//             printf("int %d cumm weight = %lf\n", i, hesfs[0]); //COMMENT
           } //for
           } //else
           break;
  } //switch

  // Step 2: distribution-specific additions on top of the integrated part.
  if (distrib == 5) hesfs[0] += 1.0-param[2];
  else if (distrib == 14) for(currj=0;currj<=samplesize;currj++) hesfs[currj]+=expectedsfs1(param[3]);
  else if (distrib == 12) for(currj=0;currj<=samplesize;currj++) {
    hesfs[currj]+=expectedsfs1(0.0);
  }

//  printf("\ttotal weight = %lf (out of 1.0)", hesfs[0]); //COMMENT
  #ifdef EXTRAPOLATE
    for(j=1;j<=samplesize;j++) hesfs[j] /= hesfs[0];
  #endif
  #ifdef EXTRAPOLATE0
//    double w=0.112004;  // AEW PGA AFR
    double w=0.041726;   // AEW PGA CAU
    for(currj=0;currj<=samplesize;currj++) hesfs[currj] += w*expectedsfs1(0.0);
  #endif
	
// Step 3: ww is the weight given to the pre-computed neutral spectrum.
// It stays 0 unless ADD_NEUT_MASS is defined and distrib is 6 or 9.
double ww=0.0;

  #ifdef ADD_NEUT_MASS
  
 // ww=gsl_cdf_gamma_P(0.00001,param[0],param[1]);

// ww = CDF of the selection distribution evaluated at 0.00001, i.e. the
// probability mass lying between 0 and 0.00001.
switch ( distrib ) {
  case 6 : ww=gsl_cdf_gamma_P(0.00001,param[0],param[1]);
	//fprintf(stderr,"in_break: %lf\n",ww);

           break;
  case 9 : ww=gsl_cdf_lognormal_P(0.00001,param[0],param[1]);
           break;
  default : ww=0.0;
           break;
}


    //for(currj=0;currj<=samplesize;currj++) hesfs[currj] += w*expectedsfs1(0.0);

#endif

  // Step 4: combine the integrated spectrum with the weighted neutral one.
  for(counter=1;counter<=samplesize;counter++) full_exp_sfs[counter]=hesfs[counter] + ww*neut_exp_sfs[counter]; //put outside the #ifdef so that full_exp_sfs_will always be defined...

//double ww=0.025;
  printf("tot weight = %lf\t",hesfs[0]);   // provides a check; should be 1 for most distributions
  double ll = 0.0;
/*
  if (logl != 0) {        // calculate negative log-likelihood
    if (poisson == 0) {   // multinomial likelihood calculation
      double sum = 0.0;
      for(j=1; j<=maxj; j++) sum += hesfs[j];
      for(j=1; j<=samplesize; j++) hesfs[j] *= (tot_snps/sum);
//      for(j=1; j<=samplesize; j++) printf("sfs[%i]=%f, hesfs[%i]=%f\n",j,sfs[j],j,hesfs[j]);  //COMMENT
      if (fixed == 3) {                                         //folded sfs
        for(j=1; j<=((samplesize-1)/2); j++) {
          ll -= (sfs[j]+sfs[samplesize-j]) *
                log(hesfs[j]+hesfs[samplesize-j]);
        }
        if (samplesize%2 == 0) {
          ll -= sfs[samplesize/2]*log(hesfs[samplesize/2]);
        }
      } else for(j=1; j<=maxj; j++) ll -= sfs[j]*log(hesfs[j]);  //unfolded
    } else {               // poisson likelihood calculation
      if (fixed == 3) {
        for(j=1; j<=((samplesize-1)/2); j++) {
          ll -= ((sfs[j]+sfs[samplesize-j]) *
                log(hesfs[j]+hesfs[samplesize-j]) -
                (hesfs[j]+hesfs[samplesize-j]));
        }
        if (samplesize%2 == 0) ll -= (sfs[samplesize/2]*log(hesfs[samplesize/2])-hesfs[samplesize/2]);
      } else for(j=1; j<=maxj; j++) ll -= (sfs[j]*log(hesfs[j])-hesfs[j]);
    }
  //} else for(j=1; j<=samplesize; j++) sfs[j] = hesfs[j];
  } else for(j=1; j<=samplesize; j++) sfs[j] = hesfs[j]+ww*neut_exp_sfs[j];
*/ 
  // Step 5: score (logl != 0) or export (logl == 0) the combined spectrum.
  if (logl != 0) {        // calculate negative log-likelihood
    if (poisson == 0) {   // multinomial likelihood calculation
      // Rescale so classes 1..maxj sum to the observed SNP count.
      double sum = 0.0;
      //for(j=1; j<=maxj; j++) sum += hesfs[j];
      for(j=1; j<=maxj; j++) sum += full_exp_sfs[j];
      //for(j=1; j<=samplesize; j++) hesfs[j] *= (tot_snps/sum);
      for(j=1; j<=samplesize; j++) full_exp_sfs[j] *= (tot_snps/sum);
//      for(j=1; j<=samplesize; j++) printf("sfs[%i]=%f, hesfs[%i]=%f\n",j,sfs[j],j,hesfs[j]);  //COMMENT
      if (fixed == 3) {                                         //folded sfs
        // Folded: classes j and samplesize-j are pooled; the middle class
        // (even samplesize only) has no partner and is added on its own.
        for(j=1; j<=((samplesize-1)/2); j++) {
          //ll -= (sfs[j]+sfs[samplesize-j]) * log(hesfs[j]+hesfs[samplesize-j]);
          ll -= (sfs[j]+sfs[samplesize-j]) * log(full_exp_sfs[j]+full_exp_sfs[samplesize-j]);
        }
        if (samplesize%2 == 0) {
          //ll -= sfs[samplesize/2]*log(hesfs[samplesize/2]);
          ll -= sfs[samplesize/2]*log(full_exp_sfs[samplesize/2]);
        }
      //} else for(j=1; j<=maxj; j++) ll -= sfs[j]*log(hesfs[j]);  //unfolded
      } else for(j=1; j<=maxj; j++) ll -= sfs[j]*log(full_exp_sfs[j]);  //unfolded
    } else {               // poisson likelihood calculation
      // Poisson: no rescaling; each term is observed*log(expected) - expected.
      if (fixed == 3) {
        for(j=1; j<=((samplesize-1)/2); j++) {
          //ll -= ((sfs[j]+sfs[samplesize-j]) * log(hesfs[j]+hesfs[samplesize-j]) - (hesfs[j]+hesfs[samplesize-j]));
          ll -= ((sfs[j]+sfs[samplesize-j]) * log(full_exp_sfs[j]+full_exp_sfs[samplesize-j]) - (full_exp_sfs[j]+full_exp_sfs[samplesize-j]));
        }
        //if (samplesize%2 == 0) ll -= (sfs[samplesize/2]*log(hesfs[samplesize/2])-hesfs[samplesize/2]);
        if (samplesize%2 == 0) ll -= (sfs[samplesize/2]*log(full_exp_sfs[samplesize/2])-full_exp_sfs[samplesize/2]);
      //} else for(j=1; j<=maxj; j++) ll -= (sfs[j]*log(hesfs[j])-hesfs[j]);
      } else for(j=1; j<=maxj; j++) ll -= (sfs[j]*log(full_exp_sfs[j])-full_exp_sfs[j]);
    }
  //} else for(j=1; j<=samplesize; j++) sfs[j] = hesfs[j];
  } else for(j=1; j<=samplesize; j++) sfs[j] = full_exp_sfs[j];
 // Summary line: demographic grid point, first three params, and the
 // (positive-sign) log-likelihood -ll.
 printf("%f\t%f\t%f\t%f\t%lf\t%lf\tLL=\t%lf\t%lf\n",OMEGA,TAU,OMEGA_B,TAU_B,param[0],param[1],param[2],-ll);
  return ll;
}

// read_neut_sfs_file: loads the neutral expected site-frequency spectrum.
//
// neut_sfs_fname: path of a text file holding samplesize whitespace-separated
//                 numbers; value i is stored in neut_exp_sfs[i], i = 1..samplesize.
//
// neut_exp_sfs must already be allocated for indices 1..samplesize.
// Exits with status 1 if the file cannot be opened or a value is negative.
// A zero entry, or an entry that could not be read (which is left at 0.0),
// only produces a warning.
void read_neut_sfs_file(char *neut_sfs_fname) {

  FILE *neut_sfsfile;
  neut_sfsfile = fopen(neut_sfs_fname, "r");
  if (neut_sfsfile == NULL) {
    printf("Error opening neut_sfs file: %s\n", neut_sfs_fname);
    exit(1);
  } //else printf("SNPfile %s opened successfully\n", SNPfname);

  int i;
  // Pre-zero so entries missing from a short file have a defined value.
  for(i=1; i<=samplesize; i++) neut_exp_sfs[i]=0.0;

  for(i=1; i<=samplesize; i++) {
    if (fscanf(neut_sfsfile, "%lf", &(neut_exp_sfs[i])) != 1)
      printf("Warning: could not read neut_exp_sfs[%d] from %s; using 0\n", i, neut_sfs_fname);
    if (neut_exp_sfs[i]<0.0) {
      // Report the array that was actually read (the message previously
      // named and printed the unrelated observed spectrum sfs[]).
      printf("Error: neut_exp_sfs[%d] is negative! (neut_exp_sfs[%d]=%f)\n",i,i,neut_exp_sfs[i]);
      exit(1);
    } else if (neut_exp_sfs[i]==0.0) printf("Warning: neut_exp_sfs[%d] is zero!\n",i);
  }
  fclose(neut_sfsfile);
}





// read_snp_file: loads the observed site-frequency spectrum.
//
// SNPfname: path of a text file holding samplesize whitespace-separated
//           counts; value i is stored in sfs[i], i = 1..samplesize.
//
// Also sets tot_snps to the sum of the counts; when fixed is 0 or 3 the
// last class (sfs[samplesize]) is excluded from that total.
// Exits with status 1 if the file cannot be opened, a count is negative, or
// the resulting total is not positive. A zero entry, or an entry that could
// not be read (which is left at 0.0), only produces a warning.
void read_snp_file(char *SNPfname) {

  FILE *snpfile;
  snpfile = fopen(SNPfname, "r");
  if (snpfile == NULL) {
    printf("Error opening snpfile file: %s\n", SNPfname);
    exit(1);
  } else printf("SNPfile %s opened successfully\n", SNPfname);

  int i;
  // Pre-zero so entries missing from a short file have a defined value.
  for(i=1; i<=samplesize; i++) sfs[i]=0.0;
  tot_snps=0.0;

  for(i=1; i<=samplesize; i++) {
    if (fscanf(snpfile, "%lf", &(sfs[i])) != 1)
      printf("Warning: could not read sfs[%d] from snp file %s; using 0\n", i, SNPfname);
    if (sfs[i]<0.0) {
      printf("Error: sfs[%d] is negative! (sfs[%d]=%f)\n",i,i,sfs[i]);
      exit(1);
    } else if (sfs[i]==0.0) printf("Warning: sfs[%d] is zero!\n",i);
    tot_snps += sfs[i];
  }
  // Without fixed differences (fixed 0) or with a folded spectrum (fixed 3),
  // the last class does not count as an observed SNP.
  if ((fixed == 0)||(fixed == 3)) tot_snps -= sfs[samplesize];
  if (tot_snps <= 0.0) {
    printf("Error: no snps were observed in snp file %s\n", SNPfname);
    exit(1);
  }
  printf("tot snps observed = %lf\n", tot_snps);
  fclose(snpfile);
}


// read_ingrid_file: loads previously computed spectra into esfs[][].
//
// ingridfname: path of a text file laid out as
//                <samplesize> <number of grids>
//              followed by (samplesize+1) values per grid.
//
// On success iter is set to the number of grids and esfs[g][0..samplesize]
// is filled for g = 0..iter-1 (the first grid is echoed to stdout).
// In every failure case (file missing, unreadable header, samplesize
// mismatch, negative grid count, truncated data) a warning is printed and
// iter is set to 0 so that the caller falls back to de novo calculation.
//
// NOTE(review): the grid count is not checked against the number of rows
// allocated for esfs (rows 0..GMAX in main); a file declaring more grids than
// that would still overrun the matrix.
void read_ingrid_file(char *ingridfname) {

  FILE* ingridfile = fopen(ingridfname, "r");
  if (ingridfile == NULL) {
    printf("WARNING: Can't open input grid file %s; de novo calcs will be used...\n", ingridfname);
    iter = 0;
    return;
  }

  int ngrids = 0, ss = 0;
  if (fscanf(ingridfile, "%i", &ss) != 1) {
    printf("WARNING: Can't read samplesize from input grid file %s; de novo calcs will be used...\n", ingridfname);
    iter = 0;
  } else if (ss != samplesize) {
    printf("samplesizes not comparable: %i <> %i; de novo calcs used\n", ss, samplesize);
    iter = 0;
  } else if ((fscanf(ingridfile, "%i", &ngrids) != 1)||(ngrids < 0)) {
    printf("WARNING: invalid grid count in input grid file %s; de novo calcs will be used...\n", ingridfname);
    iter = 0;
  } else {
    printf("iter = %i, samplesize = %i\n", ngrids, ss);
    iter = ngrids;
    int g, j, complete = 1;
    for(g=0; complete && (g<ngrids); g++) {
      for(j=0; j<=samplesize; j++) {
        if (fscanf(ingridfile, "%lf", &(esfs[g][j])) != 1) {
          // A partially filled grid would be used as if it were valid, so
          // discard everything and recompute instead.
          printf("WARNING: input grid file %s is truncated at grid %i, entry %i; de novo calcs will be used...\n", ingridfname, g, j);
          iter = 0;
          complete = 0;
          break;
        }
        if (g==0) printf("esfs[0][%i]=%f\n",j,esfs[0][j]);
      } //for j
    } //for g
  }
  fclose(ingridfile);
}


double FindMaxLikelihood(int currIter) {

  if (run_type == 2) {
    printf("Gradient search is not yet supported in this version of prfreq; please use grid search\n");
    exit(1);
  }
  double maxL=9.9e99;
  double maxLparams[max_dims], s[max_dims];
  double currLL;
  int k;

  if (minOMEGA < maxOMEGA) stepOMEGA = (maxOMEGA-minOMEGA)/stepOMEGA;
  else {maxOMEGA = minOMEGA; stepOMEGA = 1.0;}
  if (minTAU < maxTAU) stepTAU = (maxTAU-minTAU)/stepTAU;
  else {maxTAU = minTAU; stepTAU = 1.0;}
  if (minOMEGA_B < maxOMEGA_B) stepOMEGA_B=(maxOMEGA_B-minOMEGA_B)/stepOMEGA_B;
  else {maxOMEGA_B = minOMEGA_B; stepOMEGA_B = 1.0;}
  if (minTAU_B < maxTAU_B) stepTAU_B = (maxTAU_B-minTAU_B)/stepTAU_B;
  else {maxTAU_B = minTAU_B; stepTAU_B = 1.0;}
  if (minOMEGA_B2 < maxOMEGA_B2) stepOMEGA_B2=(maxOMEGA_B2-minOMEGA_B2)/stepOMEGA_B2;
  else {maxOMEGA_B2 = minOMEGA_B2; stepOMEGA_B2 = 1.0;}
  if (minTAU_B2 < maxTAU_B2) stepTAU_B2 = (maxTAU_B2-minTAU_B2)/stepTAU_B2;
  else {maxTAU_B2 = minTAU_B2; stepTAU_B2 = 1.0;}
  for(k=0; k<num_dims; k++) {
    if (minp[k]<maxp[k]) s[k] = (maxp[k]-minp[k])/step[k];
    else {maxp[k] = minp[k]; s[k] = 1.0;}
  }
  for(k=num_dims; k<max_dims; k++) {maxp[k] = minp[k]; s[k] = 1.0;}
  printf("num_dims=%d:\n", num_dims);
  for(k=0;k<num_dims;k++) printf("minp=%f,maxp=%f,step=%f,s=%f\n",minp[k],maxp[k],step[k],s[k]);

  for(OMEGA=minOMEGA;OMEGA<=maxOMEGA;OMEGA+=stepOMEGA) {
   for(OMEGA_B=minOMEGA_B;OMEGA_B<=maxOMEGA_B;OMEGA_B+=stepOMEGA_B) {
    for(TAU_B=minTAU_B;TAU_B<=maxTAU_B;TAU_B+=stepTAU_B) {
     for(OMEGA_B2=minOMEGA_B2;OMEGA_B2<=maxOMEGA_B2;OMEGA_B2+=stepOMEGA_B2){
      for(TAU_B2=minTAU_B2;TAU_B2<=maxTAU_B2;TAU_B2+=stepTAU_B2) {
       for(TAU=minTAU;TAU<=maxTAU;TAU+=stepTAU) {
//        TAU = OMEGA*0.1593/0.4175;
//        OMEGA_B = OMEGA/0.4175;
          if ((minOMEGA < maxOMEGA)||(minTAU < maxTAU)||
              (minOMEGA_B < maxOMEGA_B)||(minTAU_B < maxTAU_B)||
              (minOMEGA_B2 < maxOMEGA_B2)||(minTAU_B2 < maxTAU_B2)) iter = 0;
//  printf("demography: %f %f %f %f %f %f\n",OMEGA,TAU,OMEGA_B,TAU_B,OMEGA_B2,TAU_B2); COMMENT
  for(param[0]=minp[0];param[0]<=maxp[0];param[0]+=s[0]) {
    for(param[1]=minp[1];param[1]<=maxp[1];param[1]+=s[1]) {
      for(param[2]=minp[2];param[2]<=maxp[2];param[2]+=s[2]) {
        for(param[3]=minp[3];param[3]<=maxp[3];param[3]+=s[3]) {
          for(param[4]=minp[4];param[4]<=maxp[4];param[4]+=s[4]) {
//            printf("param[0]=%lf, param[1]=%lf, param[2]=%lf\n",param[0],param[1],param[2]);
            currLL = LL(1);
//            printf("%f %f %f %f %f\n",OMEGA,TAU,OMEGA_B,TAU_B,currLL);     // show demographic inference
//            printf("%f %f %f %f\n",param[0],param[1],param[2],currLL);     // show selection inference
            if (currLL < maxL) {
              for(k=0; k<num_dims; k++) maxLparams[k] = param[k];
              bestOMEGA=OMEGA; bestTAU=TAU; bestOMEGA_B=OMEGA_B; bestTAU_B=TAU_B; bestOMEGA_B2=OMEGA_B2; bestTAU_B2=TAU_B2;
              maxL = currLL;
            } //if
          }
        }
      }
    }
  }
  } } } } } }
  for(k=0; k<num_dims; k++) {
    param_est[currIter][k] = maxLparams[k];
    printf("param_est[%d][%d] = %f\n",currIter,k,param_est[currIter][k]);
  }
  printf("maxL = %f\n", maxL);
  return maxL;
}

// returns the # of diffs that appear fixed due to insufficient sampling
//  of outgroup (outgroup sample size = 1)
double outgroup_fixed() {
  if (fixed == 2) return 0.0;
  OUTGROUP = 1;
  int old_samplesize = samplesize;
  samplesize = 1;
  LL(0);
  samplesize = old_samplesize;
  OUTGROUP = 0;
  return sfs[1];
}


int main(int argc, char* argv[]) {

  int num_epochs;
  char tempchar, inputline[128], SNPfilename[128], ingridfilename[128], NEUTfilename[500];
  struct timeval start_time, end_time;
  double *exp_sfs;
  FILE *settings;

	strcpy(NEUTfilename, "Boyko_AFR_neut_exp_sfs.txt");
  if ((argc < 3)||(argc > 6)) {
    printf("Usage: settings_filename output_filename [input_snps_filename input_grid_filename output_grid_filename]\n");
    exit(1);
  }
  printf ("argc = %d\n",argc);
  gettimeofday(&start_time, NULL);
  settings = fopen(argv[1], "r");
  if (settings == NULL) {
    printf("Error opening settings file: %s\n", argv[1]);
    exit(1);
  } else {
    #ifdef DEBUG
    printf("settings file %s is open\n", argv[1]);
    #endif //DEBUG
  }
  outfile = fopen(argv[2], "w");
  if (outfile == NULL) {
    printf("Error opening output file: %s\n", argv[2]);
    exit(1);
  } else {
    #ifdef DEBUG
    printf("output file %s is open\n", argv[2]);
    #endif //DEBUG
  }

  fscanf(settings, "fixed = %i\n", &fixed);    // include fixed diffs?
  fscanf(settings, "poisson = %i\n", &poisson);  // poisson or multinomial?
  fscanf(settings, "samplesize = %i\n", &samplesize);
  fscanf(settings, "Ne = %lf\n", &Ne);
  fscanf(settings, "THETA = %lf\n", &THETA);
  fscanf(settings, "tdiv = %lf\n", &tdiv);
  if ((samplesize<=0)||(Ne<=0.0)||(THETA<0.0)||(tdiv<=0.0)) {
    printf("ERROR: samplesize = %d, Ne = %lf, THETA = %lf, tdiv = %lf\n",
     samplesize,Ne,THETA,tdiv);
    exit(1);
  }
  printf("POISSON = %d\n",poisson);
  delx=1.0/(2.0*Ne);
  delt=1.0/(32.0*Ne);
  al=delt/(4.0*delx);
  be=delt/(4.0*delx*delx);
  gam=delt*delx*(1.0-delx/2.0)/8.0;
  gam2=delt*(1.0-delx)/4.0;
  arsize = 2*Ne;
  fscanf(settings, "num_epochs = %i\n", &num_epochs);
  if ((num_epochs < 1)||(num_epochs > 4)) {
    printf("Warning: num_epochs = %i not valid, changed to 1\n", num_epochs);
    num_epochs = 1;
  }
  if (num_epochs == 1) {
    minTAU=0.0; minOMEGA=1.0; maxTAU=0.0; maxOMEGA=1.0; stepTAU=1.0; stepOMEGA=1.0;
  } else {
    fscanf(settings, "TAU = %lf %lf %lf\n", &minTAU, &maxTAU, &stepTAU);
    fscanf(settings, "OMEGA = %lf %lf %lf\n", &minOMEGA, &maxOMEGA, &stepOMEGA);
  }
  if (num_epochs <= 2) {
    minTAU_B=0.0; minOMEGA_B=1.0; maxTAU_B=0.0; maxOMEGA_B=1.0;
    stepTAU_B=1.0; stepOMEGA_B=1.0;
  } else {
    fscanf(settings, "TAU_B = %lf %lf %lf\n", &minTAU_B, &maxTAU_B, &stepTAU_B);
    fscanf(settings, "OMEGA_B = %lf %lf %lf\n", &minOMEGA_B, &maxOMEGA_B, &stepOMEGA_B);
  }
  if (num_epochs <= 3) {
    minTAU_B2=0.0; minOMEGA_B2=1.0; maxTAU_B2=0.0; maxOMEGA_B2=1.0;
    stepTAU_B2=1.0; stepOMEGA_B2=1.0;
  } else {
    fscanf(settings, "TAU_B2 = %lf %lf %lf\n", &minTAU_B2, &maxTAU_B2, &stepTAU_B2);
    fscanf(settings, "OMEGA_B2 = %lf %lf %lf\n", &minOMEGA_B2, &maxOMEGA_B2, &stepOMEGA_B2);
  }
  TAU = minTAU; OMEGA = minOMEGA;
  TAU_B = minTAU_B; OMEGA_B = minOMEGA_B;
  TAU_B2 = minTAU_B2; OMEGA_B2=minOMEGA_B2;
  if ((TAU_B2 < 0.0)||(TAU_B < 0.0)||(TAU < 0.0))
    printf("Warning: TAU, TAU_B, TAU_B2 should not be negative...\n");
  if (OMEGA < 0.0) printf("Warning: OMEGA should not be negative...\n");
  if ((TAU < delt)||(OMEGA == 1.0)) printf("RUNNING STATIONARY DEMOGRAPHIC MODEL...\n");
  else {
    if ((TAU_B < delt)||(OMEGA_B == 1.0)) {
      if (OMEGA < 1.0) printf("RUNNING POPULATION EXPANSION MODEL...\n");
      else printf("RUNNING POPULATION CONTRACTION MODEL...\n");
    } else {
      if (OMEGA_B > 1.0) printf("RUNNING INVERSE BOTTLENECK MODEL...\n");
      else printf("RUNNING BOTTLENECK MODEL...\n");
    }
  }
  if (TAU_B+TAU_B2 > TAU) {
    printf("Warning: TAU(bottleneck) > TAU(bottleneck+expansion)!\n");
    printf("bottleneck simulating will stop after time = TAU (%f)\n",TAU);
  }

  esfs=dmatrix(0,GMAX,0,samplesize);
  max_eps = 0.0;
  hesfs=dvector(0,samplesize);
  fscanf(settings, "distrib = %i\n", &distrib);
  switch ( distrib ) {
  case  0: num_dims = 0; break;  // neutral
  case  1: num_dims = 1; break;  // point-mass (gamma)
  case  2: num_dims = 2; break;  // 2-point (pneutral, gamma)
  case  3: num_dims = 4; break;  // 3-point (pneutral, pneg, gamneg, gampos)
  case  4: num_dims = 2; break;  // normal (mean, stdev)
  case  5: num_dims = 3; break;  // weighted norm (mean, stdev, weight)
  case  6: num_dims = 2; break;  // neg gamma (alpha, beta)
  case  7: num_dims = 3; break;  // reflected gamma (alpha, beta, ppos)
  case  8: num_dims = 3; break;  // neg gamma offest (alpha, beta, maxgam)
  case  9: num_dims = 2; break;  // neg lognormal (scale m, shape sigma)
  case 10: num_dims = 3; break;  // neg lognorm offset (scale, shape, location)
  case 11: num_dims = 5; break;  // 2-norm (mu1, mu2, sd1, sd2, pneg)
  case 12: num_dims = 3; break;  // neg exponential (lambda, pneu, pleth)
  case 13: num_dims = 4; break;  // 3pt + pos selection (pr(-500),pr(-50),pr(-5),maxgam)
  case 14: num_dims = 4; break;  // weighted norm2 (mean, stdev, weight, gam)
  case 16: num_dims = 1; break;  // lethal/neutral (pneu)
  case 17: num_dims = 3; break;  // lethal/const/neutral (pneu,gam,pgam)
  case 22: num_dims = 3; break;  // relected exp + lethal (lambda,pleth,ppos)
  default: printf("Error: distribution type %d not recognized!\n");
           exit(1); break;
  }
  #ifdef DEBUG
    printf("distrib = %d, num_dims = %d\n", distrib, num_dims);
  #endif //DEBUG
  int i, j;
  for(i=0; i<num_dims; i++) {
    fscanf(settings, "P %lf %lf %lf\n",&(minp[i]),&(maxp[i]),&(step[i]));
// TODO: insert bound-checking
    printf("dim %d: min=%lf, max=%lf, step=%lf\n",i,minp[i],maxp[i],step[i]);
  }
  num_pieces = 0;
  for(i=0; i<max_pieces; i++) {
    if ((fgets(inputline, 128, settings))&&(inputline[0]=='I')) {
      sscanf(inputline, "%c %d %lf %lf\n", &tempchar,
          &(int_type[i]), &(min_int[i]), &(max_int[i]));
      num_pieces++;
    } else break;
  }
  num_pieces = i;
  printf("Integration routine will be: \n");
  for(i=0; i<num_pieces; i++)  {
    switch ( int_type[i] ) {
    case 1: printf("\tqromo1(%lf, %lf, midpnt)\n",min_int[i],max_int[i]);break;
    case 2: printf("\tqromo1(%lf, %lf, midinf)\n",min_int[i],max_int[i]);break;
    case 3: printf("\tqromo1(%lf, %lf, midexp)\n",min_int[i],max_int[i]);break;
    case 4: printf("\tqromo1(%lf, %lf, midexpN)\n",min_int[i],max_int[i]);break;
    default: printf("ERROR: unknown integration type!\n"); exit(1); break;
    } //switch
  } //for
  printf("\n");
  if (inputline[0]=='r') read_in_snps = (inputline[15] != '0');
  else fscanf(settings, "read_in_snps = %i\n", &read_in_snps);
  int curr_arg = 3;
  if (read_in_snps != 0) strcpy(SNPfilename, argv[curr_arg++]);
  fscanf(settings, "run_type = %i\n", &run_type); // 1=grid, 2=grad, 3=power
  if ((run_type < 1)||(run_type > 5)) {           // 4=mkprf, 5=mkprfint
    printf("Error: run_type %d is not supported; must be 1 thru 5\n", run_type);
    exit(1);
  } else if ((run_type >= 4)&&(distrib != 4)) {
    printf("Error: distrib must be normal (4) for mkprf/mkprfint mode!\n");
    exit(1);
  }
  if ((run_type >= 4)&&(fixed < 1)) fixed = 1;
  fscanf(settings, "input_grids = %i\n", &input_grids);
  if (input_grids != 0) {
    // TODO: insert error checking using argc?
    strcpy(ingridfilename, argv[curr_arg++]);
    read_ingrid_file(ingridfilename);
  } else iter = 0;
  fscanf(settings, "output_grids = %i\n", &output_grids);
  if (output_grids != 0) strcpy(outgridfilename, argv[curr_arg++]);
  fscanf(settings, "num_iters = %i\n", &num_iter);
  if (run_type < 3) num_iter = 1;
  if ((run_type <= 3)&&(THETA == 0.0)) { THETA = 1.0; num_iter = 0; }
  fclose(settings);

  // TODO: insert bound checking

  fprintf(outfile,"samplesize=%d,Ne=%f,THETA=%f,TAU=%f,tdiv=%f,OMEGA=%f\n",
     samplesize,Ne,THETA,TAU,tdiv,OMEGA);
  sfs=dvector(1,samplesize);

    neut_exp_sfs=dvector(1,samplesize);
    full_exp_sfs=dvector(1,samplesize);
  D=dvector(0,2*Ne);
  Q=dvector(0,2*Ne);
  u=dvector(0,arsize); //arsize is 2Ne
  //mtmp=dvector(0,2*Ne);
  rhs=dvector(0,2*Ne);
  rt=dvector(0,2*Ne);
  lft=dvector(0,2*Ne);
  mid=dvector(0,2*Ne);

#ifdef ADD_NEUT_MASS
	//strcpy(NEUTfilename, "Boyko_AFR_neut_exp_sfs.txt");
	strcpy(NEUTfilename, argv[curr_arg++]);

	read_neut_sfs_file(NEUTfilename);
  
for(i=1; i<=samplesize; i++)
	{
		fprintf(stderr,"exp_sfs:\t%lf\n",neut_exp_sfs[i]);
	}
#endif

  if (read_in_snps) read_snp_file(SNPfilename);
  else if (run_type <= 3) {
    for(i=0; i<num_dims; i++) param[i]=minp[i];
    LL(0);
    double poly=0.0;
    for(i=1; i<samplesize; i++) poly += sfs[i];
    fprintf(outfile,"expected poly:div = %f : %f\n", poly, sfs[samplesize]);
    for(i=1; i<=samplesize; i++) fprintf(outfile,"expected sfs[%d]=%lf\n",i,sfs[i]);
    if (run_type < 3) num_iter = 0;
  }    // TODO: reinstate later once this is fixed up

// added 10/5 for calculating THETA (this is an undocumented hack)
  #ifdef CALC_THETA
     double poly=0.0;
     for(i=1; i<samplesize; i++) poly += sfs[i];
     printf("OMEGA = %f, TAU = %f\n",OMEGA,TAU);
     printf("OMEGA_B = %f, TAU_B = %f\n",OMEGA_B,TAU_B);
     printf("OMEGA_B2 = %f, TAU_B2 = %f\n",OMEGA_B2,TAU_B2);
     expectedsfs1(0.0);
     double sumt=0.0;
//   for(i=0; i<=2*Ne; i+=10) printf("u[%f]=%f\n",0.5*i/Ne,u[i]);
     for(i=1;i<samplesize;i++) {
       printf("F(%d,%d) = %f\n", i, samplesize, esfs[0][i]/THETA);
       sumt += (esfs[0][i]/THETA);
     }
     printf("denom = %f\n", sumt);
     printf("thetaR = %f\n", 2.5*poly/sumt);
     exit(1);
  #endif //CALC_THETA

  if (run_type > 3) {
    int sample = (int)THETA;
    if (THETA <= 0.0) THETA = 100.0;
    double FS, SS, FR, SR;
    param[0] = 0.0;
    param[1] = 0.0;
    LL(0);
    SS = 0.0;
    for(j=0; j<samplesize; j++) SS += sfs[j];
    FS = sfs[samplesize];
    FS += outgroup_fixed();
    for(i=0; i<num_iter; i++) {
      param[0] = rnormal(minp[0], minp[1]);
      if (run_type == 4) param[1] = 0.0;     // mkprf (stdev = 0)
//      else param[1] = exp(rnormal(0.0, 1.0)); // mkprfkint
      else param[1] = 1.0; // mkprfkint (fixed variance)
      LL(0);
      SR = 0.0;
      for(j=0; j<samplesize; j++) SR += sfs[j];
      FR = sfs[samplesize];
      FR += outgroup_fixed();
      if (run_type == 4) fprintf(outfile,"M%f\t",param[0]);
      else fprintf(outfile,"M%f S%f\t",param[0],param[1]);
      if (sample == 0.0)      // output expected values
        fprintf(outfile,"%f\t%f\t%f\t%f\t%d\t1\t500 50 0.1\n",
                FS,SS,FR,SR,samplesize);
      else                    // output simulated values
        fprintf(outfile,"%f\t%f\t%f\t%f\t%d\t1\t500 50 0.1\n",
                poidev(FS),poidev(SS),poidev(FR),poidev(SR),samplesize);
    } // for
  } else {
    param_est = dmatrix(0,num_iter+1,0,num_dims);
    if (run_type == 3) {
      exp_sfs=dvector(1,samplesize);
      for(i=0; i<=samplesize; i++) exp_sfs[i] = sfs[i];
    }
    for(i=0; i<num_iter; i++) {
      if (run_type == 3) {
        for(j=0; j<=samplesize; j++) sfs[j] = poidev(exp_sfs[j]);
      }
      double loglike = FindMaxLikelihood(i);
      printf("bestOMEGA=%f, bestTAU=%f\n",bestOMEGA,bestTAU);
      if (num_epochs > 1) {
        fprintf(outfile,"OMEGA = %lf\tTAU = %lf\n",bestOMEGA,bestTAU);
        OMEGA=bestOMEGA; TAU=bestTAU;
        if (num_epochs > 2) {
          fprintf(outfile,"OMEGA_B = %e\tTAU_B = %e\n",bestOMEGA_B,bestTAU_B);
          OMEGA_B=bestOMEGA_B; TAU_B=bestTAU_B;
          if (num_epochs > 3) {
            fprintf(outfile,"OMEGA_B2 = %e\tTAU_B2 = %e\n",bestOMEGA_B2,bestTAU_B2);
            OMEGA_B2=bestOMEGA_B2; TAU_B2=bestTAU_B2;
          }
        }
      }
      fprintf(outfile,"%lf %lf ",bestOMEGA,bestTAU);
      for(j=0; j<num_dims; j++) {
        fprintf(outfile,"%e\t",param_est[i][j]);
        param[j] = param_est[i][j];
      }
      fprintf(outfile,"\n");
      printf("OMEGA=%f, TAU = %f, OMEGA_B = %f, TAU_B = %f\n",OMEGA,TAU,OMEGA_B,TAU_B);
      LL(1);                        // TODO: double-check that this is right
      for(j=1; j<=samplesize; j++)
        fprintf(outfile,"expected sfs[%d]=%lf\n",j,sfs[j]);
      if (run_type != 3) {
        printf("\nmax LL=%f\n", -loglike);
        fprintf(outfile, "max LL=%f\n", -loglike);
      }
    }
    if ((run_type == 3)&&(num_iter > 0)) {
      fprintf(outfile,"\nPARAM\t\t95%% C.I.\t\t[min, max]\n");
      for(currParam=0; currParam<num_dims; currParam++) {
        double dmean=0.0,dstdev=0.0,dloci,dhici;
        qsort(param_est,num_iter,sizeof(double)*num_dims,comp);
        for(i=0;i<num_iter;i++) dmean+=param_est[i][currParam]/num_iter;
        for(i=0;i<num_iter;i++) dstdev+=(param_est[i][currParam]-dmean)*
                                        (param_est[i][currParam]-dmean);
        dstdev=sqrt(dstdev/(num_iter-1.0));
        dloci=0.5*param_est[(int)(ceil(0.025*num_iter))][currParam]+
              0.5*param_est[(int)(floor(0.025*num_iter))][currParam];
        dhici=0.5*param_est[(int)(ceil(0.975*num_iter))-1][currParam]+
              0.5*param_est[(int)(floor(0.975*num_iter))-1][currParam];
        fprintf(outfile,"%i\t%6.3f\t%6.3f\t(%6.3f, %6.3f)\t[%6.3f, %6.3f]\n",currParam,dmean,dstdev,dloci,dhici,param_est[0][currParam],param_est[num_iter-1][currParam]);
      } //for
      free_dvector(exp_sfs,1,samplesize);
    } //if
    free_dmatrix(param_est,0,num_iter+1,0,num_dims);
  } // else

  if (output_grids) {
    FILE* outgridfile = fopen(outgridfilename, "w");
    if (outgridfile != NULL) {
      fprintf(outgridfile, "%i\n", samplesize);
      fprintf(outgridfile, "%i\n", iter);
      for(i=0; i<iter; i++) {
        for(j=0; j<=samplesize; j++) {
          fprintf(outgridfile,"%24.12e\n",esfs[i][j]);
        }
      }
      fclose(outgridfile);
    } else printf("Warning: could not open outgrid file %s; values not saved.\n", outgridfilename);
  } //if

  //free_dvector(mtmp,0,2*Ne);
  free_dvector(rhs,0,2*Ne);
  free_dvector(lft,0,2*Ne);
  free_dvector(mid,0,2*Ne);
  free_dvector(rt,0,2*Ne);
  free_dvector(D,0,2*Ne);
  free_dvector(u,0,arsize);
  free_dmatrix(esfs,0,GMAX,0,samplesize);
  free_dvector(hesfs,0,samplesize);
  gettimeofday(&end_time, NULL);   double usec = (end_time.tv_sec - start_time.tv_sec) + 0.000001*(end_time.tv_usec - start_time.tv_usec);
  printf("Total time was %f sec\n", usec);
  #ifdef DEBUG
  printf("Max error (eps) = %e\n", max_eps);
//  printf("iright = %d, iwrong = %d\n", iright, iwrong);    // check on search efficiency
  #endif
  return 0;

}


// transient: evaluates the discretized frequency distribution dist (grid
// spacing delx) at frequency x by linear interpolation.
//   x    - frequency at which the distribution is wanted
//   dist - distribution values indexed by grid point
// Returns dist at the grid point when x/delx is a whole number, dist[1] when
// x lies in the first grid cell, dist[arsize] when the upper bracketing
// index would be arsize+1, and otherwise the value on the straight line
// through the two bracketing grid points.
double transient(double x, double *dist) {
   const int below = (int) floor(x/delx);
   const int above = (int) ceil(x/delx);

   if (below == above) return dist[below];         // x sits on a grid point
   if (below == 0) return dist[1];                 // first cell: no interpolation
   if (above == arsize+1) return dist[arsize];     // just past the last grid point

   const double slope = (dist[above]-dist[below])/delx;
   const double intercept = dist[below]-slope*delx*below;
   return slope*x+intercept;
}


// Relative population size for which D, Q, lft, mid and rt were last built.
// It starts negative (and tridag sets it to -2.0 when GAMMA changes) so that
// the next pop_size() call is forced to rebuild them.
double oldsize=-1.0;

// pop_size: picks the relative population size in effect at currtime
// (OMEGA, OMEGA_B, OMEGA_B2 or 1.0, chosen by comparing currtime against
// 0, TAU_B, TAU_B+TAU_B2 and TAU in that order) and, when it differs from
// oldsize, rebuilds the drift (D) and selection (Q) terms together with the
// tridiagonal coefficients lft, mid and rt. On a rebuild it also refreshes
// newfreq/tnew and newfreq2/tnew2 from 1/OMEGA_B and 1/OMEGA_B2.
// Exits when currtime > TAU or when newfreq/newfreq2 exceed Ne/2.
void pop_size(double currtime) {
  double size_now;
  int k;

  if (currtime <= 0.0) {
    size_now = OMEGA;
    printf("Warning: tridag called with time=%f\n",currtime);
  } else if (currtime <= TAU_B) {
    size_now = OMEGA_B;
  } else if (currtime <= TAU_B2+TAU_B) {
    size_now = OMEGA_B2;
  } else if (currtime <= TAU) {
    size_now = 1.0;
  } else {
    printf("Error in pop_size: time (%f) > TAU (%f)\n", currtime, TAU);
    exit(1);
  }

  if (size_now == oldsize) return;     // coefficients are already up to date

  // drift and selection terms at every interior grid point
  for (k=1; k<arsize; k++) {
    const double freq = k*delx;
    D[k] = freq*(1.0-freq)/size_now;
    Q[k] = 2.0*freq*(1.0-freq)*size_now*GAMMA*(freq+0.5*(1.0-2.0*freq));
  }

  // interior rows of the tridiagonal system
  for (k=1; k<arsize; k++) {
    lft[k] = -al*Q[k-1]-be*D[k-1];
    mid[k] = 1.0+2.0*be*D[k];
    rt[k] = al*Q[k+1]-be*D[k+1];
  }

  // boundary rows (mid[0] is never used, so it is left alone)
  rt[0] = -gam-gam2;
  mid[arsize] = 1.0-gam;
  lft[arsize] = -gam-gam2;

  oldsize = size_now;

  // grid position 1/OMEGA_B, split into whole index and fractional part
  newfreq = floor(1.0/OMEGA_B);
  tnew = 1.0/OMEGA_B - newfreq;
  if (newfreq > Ne/2) {
    printf("Error: Ne too small for given OMEGA_B!\n");
    exit(1);
  }

  // same for OMEGA_B2; per the original author's note, Ne here is the value
  // given in the settings file, not the actual population size
  newfreq2 = floor(1.0/OMEGA_B2);
  tnew2 = 1.0/OMEGA_B2 - newfreq2;
  if (newfreq2 > Ne/2) {
    printf("Error: Ne too small for given OMEGA_B2!\n");
    exit(1);
  }
}


// Demographic/selection parameters and elapsed time of the most recent
// tridag() run. When none of OMEGA, OMEGA_B, TAU_B, OMEGA_B2, TAU_B2 or GAMMA
// has changed, the next run continues from lasttime instead of restarting.
double lasttime=0.0, lastOMEGA=0.0, lastOMEGA_B=0.0, lastTAU_B=0.0, lastOMEGA_B2=0.0, lastTAU_B2=0.0, lastGAMMA=0.0;

// tridag: advances the global frequency distribution u[0..arsize] in steps
// of delt up to time TAU.
//   - If any cached parameter changed, u is reset to the stationary
//     distribution for (THETA, OMEGA, GAMMA) and time restarts at delt;
//     otherwise integration resumes from lasttime.
//   - Each step calls pop_size(t) to refresh lft/mid/rt, builds the
//     right-hand side from the current u, and solves the tridiagonal system
//     by forward elimination and back substitution.
//   - After each step mass is added at grid index newfreq, newfreq2 or 1
//     depending on which epoch t falls in (presumably the input of new
//     mutations -- confirm against the model description).
// Exits if arsize is larger than the fixed-size work array can hold.
void tridag(void) {
  enum { MTMP_LAST = 20000 };   // highest grid index mtmp[] can hold
  int j;
  double mtmp[MTMP_LAST+1];     // fixed-size stack array (for optimal efficiency)
  double t,tmp;

  // mtmp[] is indexed up to arsize below; stop with a clear message instead
  // of writing past the end of the work array.
  if (arsize > MTMP_LAST) {
    printf("Error in tridag: arsize (%ld) exceeds work array limit (%d)\n", (long)arsize, (int)MTMP_LAST);
    exit(1);
  }

  // mid[0] is never set; its reciprocal is supplied directly here
  mtmp[0] = 1.0/(1.0+gam);

  if ((OMEGA==lastOMEGA)&&(OMEGA_B==lastOMEGA_B)&&(TAU_B==lastTAU_B)&&(OMEGA_B2==lastOMEGA_B2)&&(TAU_B2==lastTAU_B2)&&(GAMMA==lastGAMMA)) {
    if (TAU <= lasttime) printf("Warning: current TAU (%f) <= last TAU (%f); tridag will not loop\n", TAU, lasttime);
  } else {
    lasttime = delt;
    lastOMEGA = OMEGA;
    lastOMEGA_B = OMEGA_B;
    lastTAU_B = TAU_B;
    lastOMEGA_B2 = OMEGA_B2;
    lastTAU_B2 = TAU_B2;
    if (GAMMA != lastGAMMA) {
      lastGAMMA = GAMMA;
      oldsize = -2.0;            // ensures rhs, etc are recalculated
    }
    double q;
    for (j=1; j<2*Ne; j++) {   // start with stationary distribution
      q = j*delx;
        if (GAMMA < -1.0e-7) {                 // negative selection
          u[j]=THETA*OMEGA*(exp(2.0*GAMMA*OMEGA*q)-exp(2.0*GAMMA*OMEGA))/(q*(1.0-q)*(1.0-exp(2.0*GAMMA*OMEGA)));
        } else if (GAMMA > 1.0e-7) {           // positive selection
          u[j]=THETA*OMEGA*(1.0-exp(-2.0*GAMMA*OMEGA*(1.0-q)))/(q*(1.0-q)*(1.0-exp(-2.0*GAMMA*OMEGA)));
        } else u[j] = THETA*OMEGA/q;           // neutral case
    } // for each allele frequency
    u[0]=0.0; u[arsize]=0.0;
  }

  for(t=lasttime; t<TAU; t+=delt) {
    pop_size(t);
    // right-hand side built from the current u
    for (j=1;j<arsize;j++) {
      rhs[j]=(-u[j-1]*lft[j])-u[j]*(mid[j]-2.0)-(u[j+1]*rt[j]);
    }
    rhs[0]=u[0]*(1.0-gam)+u[1]*(gam2-gam);
    rhs[arsize]=u[arsize]*(1.0+gam)+u[arsize-1]*(gam+gam2);
    // forward elimination; mtmp[j] holds the reciprocal of the pivot
    for (j=1; j<=arsize; j++) {
      tmp = lft[j]*mtmp[j-1];
      mtmp[j] = 1.0/(mid[j] - rt[j-1]*tmp);
      rhs[j] -= rhs[j-1]*tmp;
    }
    // back substitution
    u[arsize] = rhs[arsize]*mtmp[arsize];
    for (j=arsize-1; j>=0; j--) {
      rhs[j] -= rhs[j+1]*rt[j]*mtmp[j+1];
      u[j] = rhs[j]*mtmp[j];
    }
    if (t<TAU_B) {               // curr pop size = Nb (Ne*OMEGA_B)
      u[newfreq]+=(1.0-tnew)*OMEGA_B*Ne*delt*THETA/delx;
      u[newfreq+1]+=tnew*OMEGA_B*Ne*delt*THETA/delx;
    } else if (t<TAU_B+TAU_B2) { // curr pop size = Nb2 (Ne*OMEGA_B2)
      u[newfreq2]+=(1.0-tnew2)*OMEGA_B2*Ne*delt*THETA/delx;
      u[newfreq2+1]+=tnew2*OMEGA_B2*Ne*delt*THETA/delx;
    } else {                     // curr pop size = Ne
      u[1]+=Ne*delt*THETA/delx;
    } //else
  } //for
  lasttime = t;   // remember how far integration got, for a later resume
}

// qromo1: Romberg-style integration over (a,b), used for the integral over
// the distribution of fitness effects.
//   a, b   - integration limits, passed straight through to choose
//   choose - stage routine: choose(a,b,j) returns the j-th refinement of the
//            integral estimate
// Successive stage estimates are extrapolated to zero step size with polint
// over the last K stages. Returns as soon as the extrapolation error is
// within EPS relative to the estimate, or within EPS*EPS in absolute terms.
// If JMAX stages are not enough, a message is printed, max_eps is set to the
// final error and the last estimate is returned anyway.
double qromo1(double a, double b, double (*choose)(double, double, int)) {
   double estimate, err;
   double step[JMAXP+1], approx[JMAXP];
   int level;

   step[1] = 1.0;
   for (level=1; level<=JMAX; level++) {
      approx[level] = choose(a,b,level);
      if (level >= K) {
         polint(&step[level-K],&approx[level-K],K,0.0,&estimate,&err);
         const double abs_err = fabs(err);
         if (abs_err <= EPS*fabs(estimate)) {
            #ifdef DEBUG
            // track the largest error accepted by the relative test
            if (abs_err > max_eps) max_eps = abs_err;
            #endif //DEBUG
            return estimate;
         }
         if (abs_err <= EPS*EPS) return estimate;
      }
      // each stage triples the point count, so the step variable shrinks by 9
      step[level+1] = step[level]/9.0;
   }

   printf("Too many steps routing qromo1 (%f,%f): dss=%lf, EPS=%lf, ss=%lf\n",a,b,err,EPS,estimate);
   max_eps = fabs(err);
   return estimate;
}


// calculates fixed differences between two species according
// to Sawyer & Hartl (1992) eqn 13.
double Dval(double _tdiv, double _theta, double _gamma) {
   double retval = 0.5*_tdiv*_theta;
   if (_gamma < -1.0e-7) return retval*2.0*_gamma*exp(2.0*_gamma)/(exp(2.0*_gamma)-1.0);
   else if (_gamma > 1.0e-7) return retval*2.0*_gamma/(1.0-exp(-2.0*_gamma));
   else return retval;     // neutral case
}

// expectedsfs1: weighted expected site-frequency-spectrum entry for the
// current count (global currj) at selection coefficient g.
//   - currj == 0: returns weight(g) alone.
//   - OUTGROUP set: fills u with the stationary distribution for g and
//     returns weight(g)*PRFHazardSC(1,u).
//   - 1 <= currj <= samplesize: looks for a row of esfs whose column 0
//     matches g (within 1e-7). If none exists, a new row is computed with
//     tridag() and PRFHazardSC() for every count, with u[arsize] and the
//     Dval() fixed-difference term added to the last entry, and iter is
//     incremented. Returns weight(GAMMA)*esfs[row][currj].
//   - any other currj: prints an error and exits.
// iguess remembers where to look first on the next call; it is advanced past
// the row just used so that a sweep over g values in the same order hits
// consecutive rows.
// NOTE(review): nothing here bounds iter before writing esfs[iter] (the
// original carried a TODO asking "what if iter > maxiter???") -- confirm
// the caller keeps iter within the allocated rows.
int iguess = 0;
double expectedsfs1(double g) {
   if (currj == 0) return weight(g);

   if (OUTGROUP) {
     int k;
     printf("OUTGROUP: GAMMA = %f, j = %d\n", g, currj);
     u[0]=0; u[arsize]=0;

     for(k=1; k<arsize; k++) {
       const double freq = k*delx;
       if (g < -1.0e-7) {                   // negative selection
         u[k]=THETA*OMEGA*(exp(2.0*g*OMEGA*freq)-exp(2.0*g*OMEGA))/(freq*(1.0-freq)*(1.0-exp(2.0*g*OMEGA)));
       } else if (g > 1.0e-7) {             // positive selection
         u[k]=THETA*OMEGA*(1.0-exp(-2.0*g*OMEGA*(1.0-freq)))/(freq*(1.0-freq)*(1.0-exp(-2.0*g*OMEGA)));
       } else {                             // neutral case
         u[k] = THETA*OMEGA/freq;
       }
     }
     return (weight(g)*PRFHazardSC(1,u));
   }

   if ((currj <= 0)||(currj > samplesize)) {
      printf("Error: currj not in range (currj=%d)!\n",currj);
      exit(1);
      return -1.0;
   }

   GAMMA = g;

   // locate the cached row for g: try the guessed row first (usually O(1)),
   // then fall back to a linear scan from the start
   int row;
   if ((iguess<iter)&&(fabs(esfs[iguess][0] - g) < 1.0e-7)) {
      row = iguess;
      #ifdef DEBUG
         iright++;
      #endif //DEBUG
   } else {
      #ifdef DEBUG
         iwrong++;
      #endif //DEBUG
      row = 0;
      while ((row<iter)&&(fabs(esfs[row][0] - g) >= 1.0e-7)) row++;
   }
   iguess = row+1;    // assume the next query will be for the next row

   if (row >= iter) {
      // not cached: compute the whole spectrum for this GAMMA
      double *fresh = esfs[iter];
      int cnt;
      fresh[0] = GAMMA;
      tridag();
      for(cnt=1; cnt <= samplesize; cnt++) fresh[cnt] = PRFHazardSC(cnt, u);
      fresh[samplesize] += u[arsize];
      if (fixed == 2) {
         fresh[samplesize] += Dval(tdiv-TAU,THETA,GAMMA*OMEGA);
      } else {
         fresh[samplesize] += Dval(2.0*tdiv-TAU,THETA,GAMMA*OMEGA);
      }
      row = iter;
      iter++;
   }
   return (weight(GAMMA)*esfs[row][currj]);
}


// polint: polynomial interpolation/extrapolation through the n points
// (xa[1..n], ya[1..n]), evaluated at x. The arrays are indexed from 1.
//   *y  - receives the interpolated value
//   *dy - receives the last correction added to *y (callers in this file
//         use it as the error estimate)
// Calls nrerror() when two abscissas coincide.
void polint(double xa[], double ya[], int n, double x, double *y, double *dy) {
   int i, m;
   int nearest = 1;                 // index of the tabulated point closest to x
   double best = fabs(x-xa[1]);
   double *up = dvector(1,n);
   double *down = dvector(1,n);

   for (i=1; i<=n; i++) {
      const double gap = fabs(x-xa[i]);
      if (gap < best) {
         nearest = i;
         best = gap;
      }
      up[i] = ya[i];
      down[i] = ya[i];
   }

   // start from the closest tabulated value, then add one correction per column
   *y = ya[nearest];
   nearest--;

   for (m=1; m<n; m++) {
      for (i=1; i<=n-m; i++) {
         const double ho = xa[i]-x;
         const double hp = xa[i+m]-x;
         const double w = up[i+1]-down[i];
         double den = ho-hp;
         if (den == 0.0) nrerror("Error in routine polint");
         den = w/den;
         down[i] = hp*den;
         up[i] = ho*den;
      }
      // choose which of the two corrections to apply for this column
      if (2*nearest < (n-m)) {
         *dy = up[nearest+1];
      } else {
         *dy = down[nearest];
         nearest--;
      }
      *y += *dy;
   }

   free_dvector(down,1,n);
   free_dvector(up,1,n);
}

// PRFHazardSC: integrates mrate(x, count, dist) over x in (0,1), i.e. the
// expected number of sites seen `count` times in the sample given the
// frequency distribution dist. Based on qromo: Romberg integration on an
// open interval.
//   count - sample count whose expectation is wanted
//   dist  - discretized frequency distribution handed to midpnt0/mrate
// Successive midpnt0 stages are extrapolated to zero step size with polint
// over the last K stages. The result is accepted when the error estimate is
// within EPS relative to the estimate, or within EPS*EPS in absolute terms.
// If JMAX stages are not enough, a message is printed and the last estimate
// is returned.
double PRFHazardSC(int count, double *dist) {
   int j;
   double sss,dss,h[JMAXP+1],s[JMAXP];
   h[1]=1.0;
   for (j=1;j<=JMAX;j++) {
      s[j]=midpnt0(0.,1.,j,count,dist);  // integrate over x in (0,1)
      if (j >= K) {
         polint(&h[j-K],&s[j-K],K,0.0,&sss,&dss);
         // Both tests use |dss|, matching qromo1: a signed comparison against
         // EPS*EPS would accept any negative error estimate, however large.
         if ((fabs(dss) <= EPS*fabs(sss))||(fabs(dss) <= EPS*EPS)) {
           return sss;
         }
      }
      h[j+1]=h[j]/9.0;   // each stage triples the point count
   }
   printf("j=%d,JMAX=%d): integral = %e, eps = %e\n",j,JMAX,sss, fabs(dss));
   printf("PRFHazardSC - Too many steps in routing qromo");
   return sss;
}

// midpnt0: stage n of the extended midpoint rule (NR, p142) for the integral
// of mrate(x, count, dist) over (a,b).
// Stage 1 evaluates the single midpoint. Every later stage adds two new
// points per existing point and blends them into the running estimate. The
// running estimate and point count live in function-local statics, so stages
// must be requested in order n = 1, 2, 3, ... for one integral at a time.
double midpnt0(double a, double b, int n, int count, double *dist) {
   static double estimate;   // running integral estimate from the last stage
   static int npts;          // number of points the next stage subdivides
   double x, total, step, step2, denom;
   int k;

   if (n == 1) {
      npts = 1;
      estimate = (b-a)*mrate(0.5*(a+b), count,dist);
      return estimate;
   }

   denom = npts;
   step = (b-a)/(3.0*denom);
   step2 = step + step;
   x = a+0.5*step;
   total = 0.0;
   for (k=npts; k>0; k--) {
      // the new points sit alternately 2*step and step apart
      total += mrate(x, count,dist);
      x += step2;
      total += mrate(x, count,dist);
      x += step;
   }
   npts *= 3;
   estimate = (estimate+(b-a)*total/denom)/3.0;
   return estimate;
} //midpnt0

// midpnt: stage n of the extended midpoint rule (NR, p142) for the integral
// of expectedsfs1 over (a,b).
// Stage 1 evaluates the single midpoint; stage n >= 2 adds 2*3^(n-2) new
// points and blends them into the running estimate, which is kept in a
// function-local static. Stages must therefore be requested in order.
double midpnt(double a, double b, int n) {
   static double estimate;
   double x, total, step, step2, denom;
   int npts, k;

   if (n == 1) {
      estimate = (b-a)*expectedsfs1(0.5*(a+b));
      return estimate;
   }

   // npts = 3^(n-2)
   npts = 1;
   for (k=n-2; k>0; k--) npts *= 3;

   denom = npts;
   step = (b-a)/(3.0*denom);
   step2 = step+step;
   x = a+0.5*step;
   total = 0.0;
   for (k=npts; k>0; k--) {
      total += expectedsfs1(x);
      x += step2;
      total += expectedsfs1(x);
      x += step;
   }
   estimate = (estimate+(b-a)*total/denom)/3.0;
   return estimate;
}

#define FUNC(x) ((*expectedsfs1)(1.0/(x))/((x)*(x)))
// midinf: stage n of the extended midpoint rule for the integral of
// expectedsfs1 over (aa,bb), carried out in the variable 1/gamma: the limits
// become (1/bb, 1/aa) and the integrand is FUNC(x) = f(1/x)/x^2.
// Stage 1 evaluates the single midpoint; stage n >= 2 adds 2*3^(n-2) new
// points. The running estimate is a function-local static, so stages must be
// requested in order.
double midinf(double aa, double bb, int n) {
   static double estimate;
   const double hi = 1.0/aa;
   const double lo = 1.0/bb;
   double x, total, step, step2, denom;
   int npts, k;

   if (n == 1) {
      estimate = (hi-lo)*FUNC(0.5*(lo+hi));
      return estimate;
   }

   // npts = 3^(n-2)
   npts = 1;
   for (k=n-2; k>0; k--) npts *= 3;

   denom = npts;
   step = (hi-lo)/(3.0*denom);
   step2 = step+step;
   x = lo+0.5*step;
   total = 0.0;
   for (k=npts; k>0; k--) {
      total += FUNC(x);
      x += step2;
      total += FUNC(x);
      x += step;
   }
   estimate = (estimate+(hi-lo)*total/denom)/3.0;
   return estimate;
}

#define FUNCexpN(x) ((*expectedsfs1)(log(x))/(x))    // aa is -inf
// midexpN: stage n of the extended midpoint rule for the integral of
// expectedsfs1 from -infinity to bb, carried out in the variable
// x = exp(gamma): the limits become (0, exp(bb)) and the integrand is
// FUNCexpN(x) = f(log x)/x. The lower limit aa is taken to be -infinity and
// is not read.
// Stage 1 evaluates the single midpoint; stage n >= 2 adds 2*3^(n-2) new
// points. The running estimate is a function-local static, so stages must be
// requested in order.
double midexpN(double aa, double bb, int n) {
   static double estimate;
   const double hi = exp(bb);
   const double lo = 0.0;
   double x, total, step, step2, denom;
   int npts, k;

   if (n == 1) {
      estimate = (hi-lo)*FUNCexpN(0.5*(lo+hi));
      return estimate;
   }

   // npts = 3^(n-2)
   npts = 1;
   for (k=n-2; k>0; k--) npts *= 3;

   denom = npts;
   step = (hi-lo)/(3.0*denom);
   step2 = step+step;
   x = lo+0.5*step;
   total = 0.0;
   for (k=npts; k>0; k--) {
      total += FUNCexpN(x);
      x += step2;
      total += FUNCexpN(x);
      x += step;
   }
   estimate = (estimate+(hi-lo)*total/denom)/3.0;
   return estimate;
}

#define FUNCexp(x) ((*expectedsfs1)(-log(x))/(x))     // bb is inf
// midexp: stage n of the extended midpoint rule for the integral of
// expectedsfs1 from aa to +infinity, carried out in the variable
// x = exp(-gamma): the limits become (0, exp(-aa)) and the integrand is
// FUNCexp(x) = f(-log x)/x. The upper limit bb is taken to be +infinity and
// is not read.
// Stage 1 evaluates the single midpoint; stage n >= 2 adds 2*3^(n-2) new
// points. The running estimate is a function-local static, so stages must be
// requested in order.
double midexp(double aa, double bb, int n) {
   static double estimate;
   const double hi = exp(-aa);
   const double lo = 0.0;
   double x, total, step, step2, denom;
   int npts, k;

   if (n == 1) {
      estimate = (hi-lo)*FUNCexp(0.5*(lo+hi));
      return estimate;
   }

   // npts = 3^(n-2)
   npts = 1;
   for (k=n-2; k>0; k--) npts *= 3;

   denom = npts;
   step = (hi-lo)/(3.0*denom);
   step2 = step+step;
   x = lo+0.5*step;
   total = 0.0;
   for (k=npts; k>0; k--) {
      total += FUNCexp(x);
      x += step2;
      total += FUNCexp(x);
      x += step;
   }
   estimate = (estimate+(hi-lo)*total/denom)/3.0;
   return estimate;
}

// weight: weight (point mass or density) assigned to selection coefficient
// _gamma under the distribution of fitness effects selected by the global
// distrib, with its parameters taken from the global param[].
//   0  neutral point mass        1  one point mass at param[0]
//   2  two point masses          3  three point masses
//   4  normal                    5  normal scaled by param[2]
//   6  gamma on -_gamma          7  gamma on both signs, mixed by param[2]
//   8  gamma shifted by param[2] 9  lognormal on -_gamma
//   10 lognormal with 4th arg    11 mixture of two normals
//   12 neutral mass + exponential on negative _gamma
//   13 masses at -500/-50/-5 plus a term for _gamma >= 0
//   14 point mass at -param[3] plus scaled normal
//   16 neutral mass only         17 masses at 0 and param[1]
//   22 two-sided exponential
// Several cases validate param[] and exit(1) on out-of-range values. An
// unrecognized distrib (or a case that breaks out of the switch) prints an
// error and exits.
double weight(double _gamma) {
  switch ( distrib ) {
  case 0: if (fabs(_gamma) < 1.0e-7) return 1.0;            // neutral
          else return 0.0;
          break;
  case 1: if (fabs(_gamma - param[0]) < 1.0e-7) return 1.0; // 1-pt
          else return 0.0;
          break;
  case 2: if (fabs(_gamma) < 1.0e-7) return param[0];       // 2-pt
          else if (fabs(_gamma - param[1]) < 1.0e-7) return (1.0-param[0]);
          else return 0.0;
          break;
  case 3: if ((param[0] < 0.0)||(param[1] < 0.0)||(param[0]+param[1]>1.0)) {
             printf("ERROR: pneutral = %lf, pneg = %lf\n", param[0], param[1]);
             exit(1);
          }
          if (fabs(_gamma) < 1.0e-7) return param[0];       // 3-pt
          else if (fabs(_gamma - param[2]) < 1.0e-7) return (param[1]);
          else if (fabs(_gamma - param[3]) < 1.0e-7)
                 return (1.0 - param[0] - param[1]);
          else return 0.0;
          break;
  case 4: return pdf_norm(_gamma, param[0], param[1]);
          break;
  case 5: if ((param[2] < 0.0)||(param[2] > 1.0)) {
            printf("Error: weight out of range (param[2] = %f)\n",param[2]);
            exit(1);
          }
          else return param[2]*pdf_norm(_gamma, param[0], param[1]);
          break;
  case 6: if (_gamma>-1.0e-8) return 1.0;
          if (num_dims == 2) return pdf_gamma(-_gamma, param[0], param[1]);
          return (1.0-param[2])*pdf_gamma(-_gamma, param[0], param[1]);
          break;
  case 7: if ((param[2]<0.0)||(param[2]>1.0)) {
            printf("Error: weight out of range (param[2]=%f)\n",param[2]);
            exit(1);
          } else
              if (_gamma < 0.0)
                return (1.0-param[2])*pdf_gamma(-_gamma, param[0], param[1]);
              else if (_gamma > 0.0)
                return param[2]*pdf_gamma(_gamma, param[0], param[1]);
              else return 1.0;
          break;
  case 8: if (_gamma>-1.0e-8) return 1.0;
          return pdf_gamma(param[2]-_gamma, param[0], param[1]);
          break;
  case 9: if (_gamma > -1.0e-8) return 1.0;
          return pdf_lognorm(-_gamma, param[0], param[1], 0.0);
          break;
  case 10: return pdf_lognorm(-_gamma, param[0], param[1], param[2]);
          break;
  case 11: if ((param[4] < 0.0)||(param[4] > 1.0)) {
            printf("Error: weight out of range (param[4]=%f)\n",param[4]);
            exit(1);
          }
          return (param[4]*pdf_norm(_gamma,param[0],param[2]) +
                          (1.0-param[4])*pdf_norm(_gamma,param[1],param[3]));
          break;
  case 12: if (param[0] < 0.0) {
            printf("Error: weight out of range (param[0]=%f)\n",param[0]);
            exit(1);
          } else if (_gamma > 0.0) {
            // report the offending argument (_gamma); the bare name `gamma`
            // is not the parameter and must not be passed to %f
            printf("Error: weight out of range (gamma=%f)\n",_gamma);
            exit(1);
          } else if (_gamma > -1.0e-8) {
            return param[1];
          } else {
            return ((1.0-param[1]-param[2])*param[0]*exp(_gamma*param[0]));
          }
          break;
  case 13: if (_gamma == -500.0) return param[0];
          else if (_gamma == -50.0) return param[1];
          else if (_gamma == -5.0) return param[2];
          else if (_gamma >= 0.0) {
            if (param[3] == 0.0) return (1.0-param[0]-param[1]-param[2]);
            else return (2.0*(param[3]-_gamma)*(1.0-param[2]-param[1]-param[0])/(param[3]*param[3]));
          } else printf("Error in weight()!\n");
          // any other negative _gamma falls out of the switch and hits the
          // final error + exit below
          break;
  case 14: if ((param[2] < 0.0)||(param[2] > 1.0)) {
            // NOTE(review): unlike case 5 this only reports the problem and
            // carries on -- confirm whether an exit(1) was intended
            printf("Error: weight out of range (param[2] = %f)\n",param[2]);
          }
          if (fabs(_gamma + param[3]) < 1.0e-7) return (1.0 - param[2]);
          else return param[2]*pdf_norm(_gamma, param[0], param[1]);
          break;
  case 16: if (fabs(_gamma)>1e-8) {
             printf("Error: gamma must equal zero (gamm=%f)!\n",_gamma);
             exit(1);
           }
           if ((param[0]<0.0)||(param[0]>1.0)) {
             printf("Error: pneu out of range (param0=%f)!\n",param[0]);
             exit(1);
           }
           return param[0];
           break;
  case 17: if (fabs(_gamma)<1e-8) return param[0];
           if (fabs(_gamma-param[1])<1e-8) return param[2];
           printf("Warning: gamma not 0 or param[1] (gam=%f)\n",_gamma);
           return 0.0;
           break;
  case 22: if (param[1]+param[2] > 1.0) {
             printf("Error: pleth + ppos > 1! (pleth = %lf, ppos = %lf)\n",param[1],param[2]);
             exit(1);
           }
           if (_gamma<-1e-8) return ((1.0-param[1]-param[2])*param[0]*exp(_gamma*param[0]));
           else if (_gamma>1e-8) return (param[2]*param[0]*exp(-_gamma*param[0]));
           else return ((1.0-param[1]-param[2])*param[0]*exp((-1.0e-8)*param[0]));
           break;
  } //switch
  printf("ERROR in weight: distrib %d not recognized!\n",distrib);
  exit(1);
  return 0.0;   // not reached; keeps every path of this non-void function returning
}

// mrate: integrand used by midpnt0 -- x^count * (1-x)^(samplesize-count)
// times bico(samplesize, count) (presumably the binomial coefficient), times
// the distribution value transient(x, dist) at frequency x.
double mrate(double x, int count, double *dist) {
     const double p_hit = pow(x,count);
     const double p_miss = pow(1.0-x, samplesize-count);
     const double sampling = p_hit*p_miss*bico(samplesize, count);
     return (sampling*transient(x, dist));
}

