001    package calhoun.analysis.crf.features.supporting;
002    
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import calhoun.analysis.crf.statistics.BasicStats;
import calhoun.util.Assert;
011    
012    public class MaxentMotifModel  {
013            private static final long serialVersionUID = -7659288739348604129L;
014            private static final Log log = LogFactory.getLog(MaxentMotifModel.class);
015            boolean debug = log.isDebugEnabled();
016            
017            // This class is written to have a single publicly static available function
018            // "trainMaxentDistributionUsingAllPairwiseConstraints", and everything else
019            // is private.
020            //
021            // This motifs are assumed to all be of same length and at each position take
022            // a value between 0-3.
023            
024            public static double[] trainMaxentDistributionUsingAllPairwiseConstraints(List<int[]> motifExamples, int span, int nIter, double pseudocount) {
025                    
026                    int nExamples = motifExamples.size();
027                    if (nExamples == 0) { log.warn("Warning -- attempting to train a maxent distribution without any examples; the flat distribution will eventually be returned."); }
028                    for (int j=0; j<nExamples; j++) {
029                            int[] motif = motifExamples.get(j);
030                            Assert.a(motif.length == span);
031                            for (int k=0; k<span; k++) {
032                                    Assert.a( (motif[k]>=0) && (motif[k]<4) );
033                            }
034                    }
035                    
036                    List<Constraint> motifConstraints = makeAllPairwiseConstraints(motifExamples,span, pseudocount);
037                    
038                    log.debug("The numebr of motifConstraints is " + motifConstraints.size());
039                    
040                    double[] ret = trainMaxentDistribution(motifConstraints,span, nIter);
041                    
042                    return ret;
043            }
044            
045            
046            private static List<Constraint> makeAllPairwiseConstraints(List<int[]> motifExamples, int motifLen, double pseudocount) {
047                    int nMotif = motifExamples.size();
048                    Assert.a(motifLen > 1);
049                    for (int i=1; i<nMotif; i++) { Assert.a(motifExamples.get(i).length == motifLen); }
050                    
051                    List<Constraint> ret = new ArrayList<Constraint>();
052                    
053                    for (int pos1 = 0; pos1<(motifLen-1); pos1++) {
054                            for (int pos2=pos1+1; pos2<motifLen; pos2++) {
055                                    double[] counts = new double[16];
056                                    for (int i=0; i<16; i++) { counts[i] = 0.0; }
057                                    
058                                    for (int j=0; j<nMotif; j++) {
059                                            int hash = 4*motifExamples.get(j)[pos1] + motifExamples.get(j)[pos2];
060                                            counts[hash] += 1.0;
061                                    }
062                                    
063                                    double total = 0.0;
064                                    for (int i=0; i<16; i++) { total += counts[i]; }
065                                    for (int i=0; i<16; i++) { counts[i] /= total; }
066                            
067                                    Constraint c = new Constraint(motifLen, pos1,pos2, motifExamples, pseudocount);
068                                    ret.add( c );
069                                    
070                            }
071                    }
072                    
073                    return ret;
074            }
075    
076            private static double[] trainMaxentDistribution( List<Constraint> motifConstraints  ,  int span, int nIter ) {
077                    int     hSize = 1; for (int j=0; j<span; j++) { hSize *= 4; }
078                    
079                    double[] ret = new double[hSize];
080                    for (int j=0; j<hSize; j++) {
081                            ret[j] = 1/((double) hSize);
082                    }
083                    
084                    int nCon = motifConstraints.size();
085                    if ( nCon == 0 ) {
086                            log.warn("Warning -- no constraints, returning maximum entropy distribution");
087                            return ret;
088                    }
089                    
090                    for (int iter=0; iter<nIter; iter++) {
091                            int cNum = (int) (nCon*Math.random());
092                            log.debug("Enforcing constrain number " + cNum + " which is " + motifConstraints.get(cNum).stringSummary() );   
093                            ret = motifConstraints.get(cNum).enforce(ret);
094                    }
095                    
096                    return ret;             
097            }
098            
099            private static class Constraint {
100                    int span=-1;
101                    int pos1;
102                    int pos2;
103                    double[] prob;
104    
105                    int msize,size;
106                    static int[] sixteen, newsixteen, many, newmany;
107                    
108                    public Constraint(int newspan, int pos1, int pos2, List<int[]> motifExamples, double pseudocount ) {
109    
110                            if (span != newspan) {
111                                    span = newspan;
112                                    msize=1; for (int j=0; j<(span-2); j++) { msize *= 4; }
113                                    size = 16*msize;
114                            
115                                    sixteen = new int[16];
116                                    newsixteen = new int[16];
117                                    many = new int[msize];
118                                    newmany = new int[msize];
119                            }       
120                            
121                            this.pos1 = pos1;
122                            this.pos2 = pos2;
123                                            
124                            Assert.a(0<=pos1);
125                            Assert.a(pos1<pos2);
126                            Assert.a(pos2<span);
127                            Assert.a(2<=span);
128                            
129                            train(motifExamples, pseudocount);
130                    }
131    
132                    
133                    private void train(List<int[]> motifExamples, double pseudocount) {
134                            
135                            int nMotif = motifExamples.size();
136                            for (int i=0; i<nMotif; i++) { Assert.a(motifExamples.get(i).length == span); }
137                            
138                            prob = new double[16];
139                            for (int i=0; i<16; i++) { prob[i] = pseudocount; }
140                            
141                            for (int j=0; j<nMotif; j++) {
142                                    int hash = 4*motifExamples.get(j)[pos1] + motifExamples.get(j)[pos2];
143                                    prob[hash] += 1.0;
144                            }
145                            
146                            double total = 0.0;
147                            for (int i=0; i<16; i++) { total += prob[i]; }
148                            for (int i=0; i<16; i++) { prob[i] /= total; }
149                            
150                            double sum = BasicStats.sumDoubleArray(prob);
151                            Assert.a( (sum > 0.999) && (sum < 1.001) );
152                    }
153    
154                    
155                    public double[] enforce(double[] pp) {
156    
157                            
158                            double[] qq = pp;
159                            
160                            
161                            /* We wish first to represent the range from 0 to size-1 as the cross-sum of two
162                             * arrays of integers: one representing the 16 probabilities constrained by the constraint,
163                             * and one representing the values of the big joint that are all within the same cell
164                             * of the constraint.
165                             */
166                            
167                            many[0]=0;        int nM=1;
168                            sixteen[0]=0;     int nS=1;
169                            
170                            for (int pos=0; pos<span; pos++) {
171                                    if ( (pos==pos1) || (pos==pos2) ) {
172                                            //List<Integer> newsixteen = new ArrayList<Integer>(); 
173                                            for (int j=0; j<nS; j++) {
174                                                    int temp = 4*sixteen[j];
175                                                    newsixteen[4*j]   = temp;
176                                                    newsixteen[4*j+1] = temp+1;
177                                                    newsixteen[4*j+2] = temp+2;
178                                                    newsixteen[4*j+3] = temp+3;
179                                            }
180                                            nS*=4;
181                                            for (int j=0; j<nS; j++) { sixteen[j] = newsixteen[j]; }
182                                            //List<Integer> newmany = new ArrayList<Integer>(); 
183                                            for (int j=0; j<nM; j++) {
184                                                    many[j] *= 4;
185                                            }
186                                    } else {
187                                            //List<Integer> newsixteen = new ArrayList<Integer>(); 
188                                            for (int j=0; j<nS; j++) {
189                                                    sixteen[j] *= 4;
190                                            }
191                                            //List<Integer> newmany = new ArrayList<Integer>(); 
192                                            for (int j=0; j<nM; j++) {
193                                                    int temp = 4*many[j];
194                                                    newmany[4*j] = temp;
195                                                    newmany[4*j+1] = temp+1;
196                                                    newmany[4*j+2] = temp+2;
197                                                    newmany[4*j+3] = temp+3;
198                                            }
199                                            nM *= 4;
200                                            for (int j=0; j<nM; j++) { many[j] = newmany[j]; }
201                                    }       
202                            }
203                            
204                            
205                            Assert.a(nS == 16);
206                            Assert.a(nM == msize);
207                            
208                            
209                            double changeneeded = 0.00;
210                            for (int z=0; z<16; z++) {
211                                    int base = sixteen[z];
212                                    double total = 0.0;
213                                    for (int t=0; t<nM; t++) {
214                                            total += qq[base + many[t]];
215                                    }
216                                    if ( prob[z]>0 ) {
217                                            Assert.a(total>0);
218                                            changeneeded += Math.abs(prob[z] - total);
219                                            double mult = prob[z]/total;
220                                            for (int t=0; t<nM; t++) {
221                                                    int temp = base + many[t]; 
222                                                    qq[temp] = qq[temp]*mult;
223                                            }
224                                    }
225                            }
226    
227                            System.out.println("amount of change needed to enforce constraint was " + changeneeded );
228                            
229                            return qq;
230                    }
231    
232                    public String stringSummary() {
233                            String ret = "constraint_pos1=" + pos1 + "_pos2="+pos2 + "_span="+span;
234                            return ret;
235                    }
236            }
237            
238    }
239