001 package calhoun.analysis.crf.solver;
002
003 import java.io.BufferedWriter;
004 import java.util.Arrays;
005 import java.util.Collections;
006 import java.util.List;
007
008 import org.apache.commons.logging.Log;
009 import org.apache.commons.logging.LogFactory;
010
011 import calhoun.analysis.crf.ModelManager;
012 import calhoun.analysis.crf.io.InputSequence;
013 import calhoun.analysis.crf.io.TrainingSequence;
014 import calhoun.util.Assert;
015 import calhoun.util.DenseBooleanMatrix2D;
016 import calhoun.util.DenseIntMatrix2D;
017 import calhoun.util.FileUtil;
018
019 /** basic functionality common to most cache processors. */
020 public abstract class CacheProcessorBasic implements CacheProcessor {
021 private static final Log log = LogFactory.getLog(CacheProcessorBasic.class);
022
023 String trainingFile = null;
024 BufferedWriter trainingWriter = null;
025 protected short[] maxStateLengths;
026
027 protected List<? extends TrainingSequence<?>> data;
028 protected ModelManager fm;
029
030 protected SolverSetup modelInfo;
031 protected FeatureEvaluation[] evals;
032 protected LengthFeatureEvaluation[][] lengthEvals;
033
034 protected double[] featureSums; // sums of each feature through entire training data set; length of array is number of features in fm.
035 protected double[][] seqFeatureSums; // sums of each feature through each sequence; length of array is number of features in fm.
036
037 protected void computeFeatureSums() {
038 trainingWriter = FileUtil.safeOpen(trainingFile);
039
040 int numFeatures = fm.getNumFeatures();
041 seqFeatureSums = new double[data.size()][numFeatures];
042 featureSums = new double[numFeatures];
043 double[] lastSegmentFeatureSums = new double[numFeatures];
044
045 for (int seqnum=0; seqnum<data.size(); seqnum++) {
046 TrainingSequence seq = data.get(seqnum);
047
048 int prevSegmentState = -1;
049 int segmentLength = 1;
050 int lastStart = 0;
051
052 for (int pos=0; pos < seq.length(); pos++) {
053 int state = seq.getY(pos);
054
055 evaluatePosition(seqnum, pos);
056 sumFeatures(featureSums, seqFeatureSums[seqnum], evals[state]);
057
058 if (pos>0) {
059 int pot = modelInfo.nStates + modelInfo.transitionIndex.getQuick(seq.getY(pos-1),state);
060 sumFeatures(featureSums, seqFeatureSums[seqnum], evals[pot]);
061 }
062
063 // If we are at the end of a segment, evaluate length features
064 if(pos == seq.length() - 1 || state != seq.getY(pos+1)) {
065 evaluateSegmentsEndingAt(seqnum, pos);
066
067 // Find the corrent node
068 int nodeIndex=0;
069 for(; nodeIndex<modelInfo.statesWithLookback.length; ++nodeIndex) {
070 if(modelInfo.statesWithLookback[nodeIndex].state == state)
071 break;
072 }
073 if(nodeIndex != modelInfo.statesWithLookback.length) {
074 // Find the correct lookback entry
075 int lbIndex = 0;
076 while(lengthEvals[nodeIndex][lbIndex].lookback != segmentLength-1) {
077 if(lengthEvals[nodeIndex][lbIndex].lookback == -1) {
078 Assert.a(false, "Lookback not listed. State: ", modelInfo.statesWithLookback[nodeIndex].state, " Seq: ", seqnum, " Pos: ", pos, " Len: ", segmentLength, " # Lookbacks: ", lbIndex);
079 }
080 ++lbIndex;
081 }
082
083 sumFeatures(featureSums, seqFeatureSums[seqnum], lengthEvals[nodeIndex][lbIndex].nodeEval);
084 if(prevSegmentState != -1 && lengthEvals[nodeIndex][lbIndex].edgeEvals != null) {
085 throw new UnsupportedOperationException("ComputeFeatureSums doesn't handle explicit length edge evals yet.");
086 }
087 }
088 prevSegmentState = state;
089 lastStart = pos-segmentLength+1;
090 segmentLength = 1;
091 }
092 else {
093 segmentLength += 1;
094 }
095
096 // Write out the sums for this segment
097 if(trainingWriter != null) {
098 if(pos == seq.length() - 1 || state != seq.getY(pos+1)) {
099 for (int i=0; i<numFeatures; i++) {
100 FileUtil.safeWrite(trainingWriter, String.format("Seq: %d Seg: %d-%d State: %d Feat: %d Val: %f\n", seqnum, lastStart, pos, state, i, featureSums[i] - lastSegmentFeatureSums[i]));
101 lastSegmentFeatureSums[i] = featureSums[i];
102 }
103 }
104 }
105 }
106 }
107 if(log.isDebugEnabled()) {
108 log.debug("We just computed the feature sums on the training data. The feature sums are (id,name,sum)");
109 for (int j=0; j<numFeatures; j++) {
110 log.debug("(" + j + ","+fm.getFeatureName(j) + "," + featureSums[j] + ")");
111 }
112 }
113 FileUtil.safeClose(trainingWriter);
114 }
115
116 void sumFeatures(double[] featureSums, double[] seqFeatureSum, FeatureEvaluation eval) {
117 int i=0;
118 while(eval.index[i] != -1) {
119 featureSums[eval.index[i]] += eval.value[i];
120 seqFeatureSum[eval.index[i]] += eval.value[i];
121 ++i;
122 }
123 }
124
125 public void setInputData(ModelManager fm, InputSequence<?> seq) {
126 // Create a dummy set of hidden states all -1
127 int[] dummyHiddenStates = new int[seq.length()];
128 Arrays.fill(dummyHiddenStates, Integer.MIN_VALUE);
129 setTrainingData(fm, Collections.singletonList(new TrainingSequence<Object>(seq, dummyHiddenStates)));
130 }
131
132 public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
133 this.fm = fm;
134 this.data = data;
135 }
136
137 public double[] getFeatureSums() {
138 return featureSums;
139 }
140
141 public double[][] getSequenceFeatureSums() {
142 return seqFeatureSums;
143 }
144
145 void basicInit(boolean allPaths) {
146 initSequenceInfo();
147 initTransitions(allPaths);
148 evals = FeatureEvaluation.create(modelInfo.nPotentials, Math.max(5, modelInfo.nFeatures));
149 }
150
151 public SolverSetup getSolverSetup() {
152 return modelInfo;
153 }
154
155 public FeatureEvaluation[] getFeatureEvaluations() {
156 return evals;
157 }
158
159 public LengthFeatureEvaluation[][] getLengthFeatureEvaluations() {
160 return lengthEvals;
161 }
162
163 protected void initSequenceInfo() {
164 // Initial basic parameters
165 modelInfo = new SolverSetup();
166 modelInfo.nFeatures = fm.getNumFeatures();
167 modelInfo.nStates = fm.getNumStates();
168
169 // Information about the input data
170 modelInfo.nSeqs = data.size();
171 modelInfo.seqOffsets = new int[modelInfo.nSeqs+1];
172 modelInfo.seqOffsets[0] = 0;
173 modelInfo.longestSeq = 0;
174 modelInfo.totalPositions = 0;
175 for (int i = 0; i < modelInfo.nSeqs; ++i) {
176 TrainingSequence seq = data.get(i);
177 int seqLen = seq.length();
178 modelInfo.longestSeq = Math.max(seqLen, modelInfo.longestSeq);
179 modelInfo.seqOffsets[i + 1] = modelInfo.seqOffsets[i] + seqLen;
180 modelInfo.totalPositions += seqLen;
181 }
182 }
183
184 protected boolean isSemiMarkovState(int state) {
185 return maxStateLengths == null ? false : maxStateLengths[state] > 1;
186 }
187
188 protected void initTransitions(boolean allPaths) {
189 // Initial the transition information
190 modelInfo.transitionIndex = new DenseIntMatrix2D(modelInfo.nStates, modelInfo.nStates);
191 modelInfo.transitionIndex.assign(-1);
192 modelInfo.selfTransitions = new int[modelInfo.nStates];
193 Arrays.fill(modelInfo.selfTransitions, -1);
194 DenseBooleanMatrix2D transitions = fm.getLegalTransitions();
195 if(transitions == null || allPaths) {
196 transitions = new DenseBooleanMatrix2D(modelInfo.nStates, modelInfo.nStates);
197 transitions.assign(true);
198 }
199
200 short count = 0;
201 for(short i = 0; i<modelInfo.nStates; ++i) {
202 for(short j = 0; j<modelInfo.nStates; ++j) {
203 if(transitions.getQuick(i, j) || (i == j && isSemiMarkovState(i)))
204 count++;
205 }
206 }
207 modelInfo.nTransitions = count;
208 modelInfo.nPotentials = modelInfo.nStates + modelInfo.nTransitions;
209 modelInfo.orderedPotentials = new short[modelInfo.nPotentials];
210 modelInfo.transitionFrom = new short[modelInfo.nTransitions];
211 modelInfo.transitionTo = new short[modelInfo.nTransitions];
212 count = 0;
213 int orderedCount = 0;
214 for(short i = 0; i<modelInfo.nStates; ++i) {
215 modelInfo.orderedPotentials[orderedCount] = i;
216 orderedCount++;
217 for(short j = 0; j<modelInfo.nStates; ++j) {
218 if(transitions.getQuick(j, i) || (i == j && isSemiMarkovState(i))) {
219 modelInfo.orderedPotentials[orderedCount] = (short) (modelInfo.nStates+count);
220 orderedCount++;
221 if(i == j)
222 modelInfo.selfTransitions[i] = count;
223 modelInfo.transitionIndex.setQuick(j, i, count);
224 modelInfo.transitionFrom[count] = j;
225 modelInfo.transitionTo[count] = i;
226 ++count;
227 }
228 }
229 }
230 Assert.a(count == modelInfo.nTransitions);
231 }
232
233 public String getTrainingFile() {
234 return trainingFile;
235 }
236
237 public void setTrainingFile(String trainingFile) {
238 this.trainingFile = trainingFile;
239 }
240
241 /**
242 * @return the data
243 */
244 public List<? extends TrainingSequence<?>> getData() {
245 return data;
246 }
247 }