001 package calhoun.analysis.crf.features.supporting;
002
003 import java.util.ArrayList;
004 import java.util.List;
005
006 import org.apache.commons.logging.Log;
007 import org.apache.commons.logging.LogFactory;
008
009 import calhoun.analysis.crf.statistics.BasicStats;
010 import calhoun.util.Assert;
011
012 public class MaxentMotifModel {
013 private static final long serialVersionUID = -7659288739348604129L;
014 private static final Log log = LogFactory.getLog(MaxentMotifModel.class);
015 boolean debug = log.isDebugEnabled();
016
017 // This class is written to have a single publicly static available function
018 // "trainMaxentDistributionUsingAllPairwiseConstraints", and everything else
019 // is private.
020 //
021 // This motifs are assumed to all be of same length and at each position take
022 // a value between 0-3.
023
024 public static double[] trainMaxentDistributionUsingAllPairwiseConstraints(List<int[]> motifExamples, int span, int nIter, double pseudocount) {
025
026 int nExamples = motifExamples.size();
027 if (nExamples == 0) { log.warn("Warning -- attempting to train a maxent distribution without any examples; the flat distribution will eventually be returned."); }
028 for (int j=0; j<nExamples; j++) {
029 int[] motif = motifExamples.get(j);
030 Assert.a(motif.length == span);
031 for (int k=0; k<span; k++) {
032 Assert.a( (motif[k]>=0) && (motif[k]<4) );
033 }
034 }
035
036 List<Constraint> motifConstraints = makeAllPairwiseConstraints(motifExamples,span, pseudocount);
037
038 log.debug("The numebr of motifConstraints is " + motifConstraints.size());
039
040 double[] ret = trainMaxentDistribution(motifConstraints,span, nIter);
041
042 return ret;
043 }
044
045
046 private static List<Constraint> makeAllPairwiseConstraints(List<int[]> motifExamples, int motifLen, double pseudocount) {
047 int nMotif = motifExamples.size();
048 Assert.a(motifLen > 1);
049 for (int i=1; i<nMotif; i++) { Assert.a(motifExamples.get(i).length == motifLen); }
050
051 List<Constraint> ret = new ArrayList<Constraint>();
052
053 for (int pos1 = 0; pos1<(motifLen-1); pos1++) {
054 for (int pos2=pos1+1; pos2<motifLen; pos2++) {
055 double[] counts = new double[16];
056 for (int i=0; i<16; i++) { counts[i] = 0.0; }
057
058 for (int j=0; j<nMotif; j++) {
059 int hash = 4*motifExamples.get(j)[pos1] + motifExamples.get(j)[pos2];
060 counts[hash] += 1.0;
061 }
062
063 double total = 0.0;
064 for (int i=0; i<16; i++) { total += counts[i]; }
065 for (int i=0; i<16; i++) { counts[i] /= total; }
066
067 Constraint c = new Constraint(motifLen, pos1,pos2, motifExamples, pseudocount);
068 ret.add( c );
069
070 }
071 }
072
073 return ret;
074 }
075
076 private static double[] trainMaxentDistribution( List<Constraint> motifConstraints , int span, int nIter ) {
077 int hSize = 1; for (int j=0; j<span; j++) { hSize *= 4; }
078
079 double[] ret = new double[hSize];
080 for (int j=0; j<hSize; j++) {
081 ret[j] = 1/((double) hSize);
082 }
083
084 int nCon = motifConstraints.size();
085 if ( nCon == 0 ) {
086 log.warn("Warning -- no constraints, returning maximum entropy distribution");
087 return ret;
088 }
089
090 for (int iter=0; iter<nIter; iter++) {
091 int cNum = (int) (nCon*Math.random());
092 log.debug("Enforcing constrain number " + cNum + " which is " + motifConstraints.get(cNum).stringSummary() );
093 ret = motifConstraints.get(cNum).enforce(ret);
094 }
095
096 return ret;
097 }
098
099 private static class Constraint {
100 int span=-1;
101 int pos1;
102 int pos2;
103 double[] prob;
104
105 int msize,size;
106 static int[] sixteen, newsixteen, many, newmany;
107
108 public Constraint(int newspan, int pos1, int pos2, List<int[]> motifExamples, double pseudocount ) {
109
110 if (span != newspan) {
111 span = newspan;
112 msize=1; for (int j=0; j<(span-2); j++) { msize *= 4; }
113 size = 16*msize;
114
115 sixteen = new int[16];
116 newsixteen = new int[16];
117 many = new int[msize];
118 newmany = new int[msize];
119 }
120
121 this.pos1 = pos1;
122 this.pos2 = pos2;
123
124 Assert.a(0<=pos1);
125 Assert.a(pos1<pos2);
126 Assert.a(pos2<span);
127 Assert.a(2<=span);
128
129 train(motifExamples, pseudocount);
130 }
131
132
133 private void train(List<int[]> motifExamples, double pseudocount) {
134
135 int nMotif = motifExamples.size();
136 for (int i=0; i<nMotif; i++) { Assert.a(motifExamples.get(i).length == span); }
137
138 prob = new double[16];
139 for (int i=0; i<16; i++) { prob[i] = pseudocount; }
140
141 for (int j=0; j<nMotif; j++) {
142 int hash = 4*motifExamples.get(j)[pos1] + motifExamples.get(j)[pos2];
143 prob[hash] += 1.0;
144 }
145
146 double total = 0.0;
147 for (int i=0; i<16; i++) { total += prob[i]; }
148 for (int i=0; i<16; i++) { prob[i] /= total; }
149
150 double sum = BasicStats.sumDoubleArray(prob);
151 Assert.a( (sum > 0.999) && (sum < 1.001) );
152 }
153
154
155 public double[] enforce(double[] pp) {
156
157
158 double[] qq = pp;
159
160
161 /* We wish first to represent the range from 0 to size-1 as the cross-sum of two
162 * arrays of integers: one representing the 16 probabilities constrained by the constraint,
163 * and one representing the values of the big joint that are all within the same cell
164 * of the constraint.
165 */
166
167 many[0]=0; int nM=1;
168 sixteen[0]=0; int nS=1;
169
170 for (int pos=0; pos<span; pos++) {
171 if ( (pos==pos1) || (pos==pos2) ) {
172 //List<Integer> newsixteen = new ArrayList<Integer>();
173 for (int j=0; j<nS; j++) {
174 int temp = 4*sixteen[j];
175 newsixteen[4*j] = temp;
176 newsixteen[4*j+1] = temp+1;
177 newsixteen[4*j+2] = temp+2;
178 newsixteen[4*j+3] = temp+3;
179 }
180 nS*=4;
181 for (int j=0; j<nS; j++) { sixteen[j] = newsixteen[j]; }
182 //List<Integer> newmany = new ArrayList<Integer>();
183 for (int j=0; j<nM; j++) {
184 many[j] *= 4;
185 }
186 } else {
187 //List<Integer> newsixteen = new ArrayList<Integer>();
188 for (int j=0; j<nS; j++) {
189 sixteen[j] *= 4;
190 }
191 //List<Integer> newmany = new ArrayList<Integer>();
192 for (int j=0; j<nM; j++) {
193 int temp = 4*many[j];
194 newmany[4*j] = temp;
195 newmany[4*j+1] = temp+1;
196 newmany[4*j+2] = temp+2;
197 newmany[4*j+3] = temp+3;
198 }
199 nM *= 4;
200 for (int j=0; j<nM; j++) { many[j] = newmany[j]; }
201 }
202 }
203
204
205 Assert.a(nS == 16);
206 Assert.a(nM == msize);
207
208
209 double changeneeded = 0.00;
210 for (int z=0; z<16; z++) {
211 int base = sixteen[z];
212 double total = 0.0;
213 for (int t=0; t<nM; t++) {
214 total += qq[base + many[t]];
215 }
216 if ( prob[z]>0 ) {
217 Assert.a(total>0);
218 changeneeded += Math.abs(prob[z] - total);
219 double mult = prob[z]/total;
220 for (int t=0; t<nM; t++) {
221 int temp = base + many[t];
222 qq[temp] = qq[temp]*mult;
223 }
224 }
225 }
226
227 System.out.println("amount of change needed to enforce constraint was " + changeneeded );
228
229 return qq;
230 }
231
232 public String stringSummary() {
233 String ret = "constraint_pos1=" + pos1 + "_pos2="+pos2 + "_span="+span;
234 return ret;
235 }
236 }
237
238 }
239