001    package calhoun.analysis.crf.solver;
002    
003    import java.io.BufferedWriter;
004    import java.util.Arrays;
005    import java.util.Collections;
006    import java.util.List;
007    
008    import org.apache.commons.logging.Log;
009    import org.apache.commons.logging.LogFactory;
010    
011    import calhoun.analysis.crf.ModelManager;
012    import calhoun.analysis.crf.io.InputSequence;
013    import calhoun.analysis.crf.io.TrainingSequence;
014    import calhoun.util.Assert;
015    import calhoun.util.DenseBooleanMatrix2D;
016    import calhoun.util.DenseIntMatrix2D;
017    import calhoun.util.FileUtil;
018    
019    /** basic functionality common to most cache processors.  */
020    public abstract class CacheProcessorBasic implements CacheProcessor {
021            private static final Log log = LogFactory.getLog(CacheProcessorBasic.class);
022            
023            String trainingFile = null;
024            BufferedWriter trainingWriter = null;
025            protected short[] maxStateLengths;
026    
027            protected List<? extends TrainingSequence<?>> data;
028            protected ModelManager fm;
029            
030            protected SolverSetup modelInfo;
031            protected FeatureEvaluation[] evals;
032            protected LengthFeatureEvaluation[][] lengthEvals;
033    
034            protected double[] featureSums; // sums of each feature through entire training data set; length of array is number of features in fm.
035            protected double[][] seqFeatureSums; // sums of each feature through each sequence; length of array is number of features in fm.
036    
037            protected void computeFeatureSums() {
038                    trainingWriter = FileUtil.safeOpen(trainingFile);
039                    
040                    int numFeatures = fm.getNumFeatures();
041                    seqFeatureSums = new double[data.size()][numFeatures];
042                    featureSums = new double[numFeatures];
043                    double[] lastSegmentFeatureSums = new double[numFeatures];
044                    
045                    for (int seqnum=0; seqnum<data.size(); seqnum++) {
046                            TrainingSequence seq = data.get(seqnum);
047    
048                            int prevSegmentState = -1;
049                            int segmentLength = 1;
050                            int lastStart = 0;
051    
052                            for (int pos=0; pos < seq.length(); pos++) {
053                                    int state = seq.getY(pos);
054    
055                                    evaluatePosition(seqnum, pos);
056                                    sumFeatures(featureSums, seqFeatureSums[seqnum], evals[state]);
057                                    
058                                    if (pos>0) {
059                                            int pot = modelInfo.nStates + modelInfo.transitionIndex.getQuick(seq.getY(pos-1),state);
060                                            sumFeatures(featureSums, seqFeatureSums[seqnum], evals[pot]);
061                                    }
062                                    
063                                    // If we are at the end of a segment, evaluate length features
064                                    if(pos == seq.length() - 1 || state != seq.getY(pos+1)) {
065                                            evaluateSegmentsEndingAt(seqnum, pos);
066                                            
067                                            // Find the corrent node
068                                            int nodeIndex=0;
069                                            for(; nodeIndex<modelInfo.statesWithLookback.length; ++nodeIndex) {
070                                                    if(modelInfo.statesWithLookback[nodeIndex].state == state)
071                                                            break;
072                                            }
073                                            if(nodeIndex != modelInfo.statesWithLookback.length) {
074                                                    // Find the correct lookback entry
075                                                    int lbIndex = 0;
076                                                    while(lengthEvals[nodeIndex][lbIndex].lookback != segmentLength-1) {
077                                                            if(lengthEvals[nodeIndex][lbIndex].lookback == -1) {
078                                                                    Assert.a(false, "Lookback not listed. State: ", modelInfo.statesWithLookback[nodeIndex].state, " Seq: ", seqnum, " Pos: ", pos, " Len: ", segmentLength, " # Lookbacks: ", lbIndex);
079                                                            }
080                                                            ++lbIndex;
081                                                    }
082                                                    
083                                                    sumFeatures(featureSums, seqFeatureSums[seqnum], lengthEvals[nodeIndex][lbIndex].nodeEval);
084                                                    if(prevSegmentState != -1 && lengthEvals[nodeIndex][lbIndex].edgeEvals != null) {
085                                                            throw new UnsupportedOperationException("ComputeFeatureSums doesn't handle explicit length edge evals yet.");
086                                                    }
087                                            }
088                                            prevSegmentState = state;
089                                            lastStart = pos-segmentLength+1;
090                                            segmentLength = 1;
091                                    }
092                                    else {
093                                            segmentLength += 1;
094                                    }
095                                    
096                                    // Write out the sums for this segment
097                                    if(trainingWriter != null) {
098                                            if(pos == seq.length() - 1 || state != seq.getY(pos+1)) {
099                                                    for (int i=0; i<numFeatures; i++) {
100                                                            FileUtil.safeWrite(trainingWriter, String.format("Seq: %d Seg: %d-%d State: %d Feat: %d Val: %f\n", seqnum, lastStart, pos, state, i, featureSums[i] - lastSegmentFeatureSums[i]));
101                                                            lastSegmentFeatureSums[i] = featureSums[i];
102                                                    }
103                                            }
104                                    }
105                            }
106                    }
107                    if(log.isDebugEnabled()) {
108                            log.debug("We just computed the feature sums on the training data.  The feature sums are (id,name,sum)");
109                            for (int j=0; j<numFeatures; j++) {
110                                    log.debug("(" + j + ","+fm.getFeatureName(j) + "," + featureSums[j] + ")");
111                            }
112                    }
113                    FileUtil.safeClose(trainingWriter);
114            }
115    
116            void sumFeatures(double[] featureSums, double[] seqFeatureSum, FeatureEvaluation eval) {
117                    int i=0;
118                    while(eval.index[i] != -1) {
119                            featureSums[eval.index[i]] += eval.value[i];
120                            seqFeatureSum[eval.index[i]] += eval.value[i];
121                            ++i;
122                    }
123            }
124            
125            public void setInputData(ModelManager fm, InputSequence<?> seq) {
126                    // Create a dummy set of hidden states all -1
127                    int[] dummyHiddenStates = new int[seq.length()];
128                    Arrays.fill(dummyHiddenStates, Integer.MIN_VALUE);
129                    setTrainingData(fm, Collections.singletonList(new TrainingSequence<Object>(seq, dummyHiddenStates)));
130            }
131    
132            public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
133                    this.fm = fm;
134                    this.data = data;
135            }
136    
137            public double[] getFeatureSums() {
138                    return featureSums;
139            }
140            
141            public double[][] getSequenceFeatureSums() {
142                    return seqFeatureSums;
143            }
144            
145            void basicInit(boolean allPaths) {
146                    initSequenceInfo();
147                    initTransitions(allPaths);
148                    evals = FeatureEvaluation.create(modelInfo.nPotentials, Math.max(5, modelInfo.nFeatures));
149            }
150    
151            public SolverSetup getSolverSetup() {
152                    return modelInfo;
153            }
154    
155            public FeatureEvaluation[] getFeatureEvaluations() {
156                    return evals;
157            }
158            
159            public LengthFeatureEvaluation[][] getLengthFeatureEvaluations() {
160                    return lengthEvals;
161            }
162            
163            protected void initSequenceInfo() {
164                    // Initial basic parameters
165                    modelInfo = new SolverSetup();
166                    modelInfo.nFeatures = fm.getNumFeatures();
167                    modelInfo.nStates = fm.getNumStates();
168    
169                    // Information about the input data
170                    modelInfo.nSeqs = data.size();
171                    modelInfo.seqOffsets = new int[modelInfo.nSeqs+1];
172                    modelInfo.seqOffsets[0] = 0;
173                    modelInfo.longestSeq = 0;
174                    modelInfo.totalPositions = 0;
175                    for (int i = 0; i < modelInfo.nSeqs; ++i) {
176                            TrainingSequence seq = data.get(i);
177                            int seqLen = seq.length();
178                            modelInfo.longestSeq = Math.max(seqLen, modelInfo.longestSeq);
179                            modelInfo.seqOffsets[i + 1] = modelInfo.seqOffsets[i] + seqLen;
180                            modelInfo.totalPositions += seqLen;
181                    }
182            }
183            
184            protected boolean isSemiMarkovState(int state) {
185                    return maxStateLengths == null ? false : maxStateLengths[state] > 1;
186            }
187    
188            protected void initTransitions(boolean allPaths) {
189                    // Initial the transition information
190                    modelInfo.transitionIndex = new DenseIntMatrix2D(modelInfo.nStates, modelInfo.nStates);
191                    modelInfo.transitionIndex.assign(-1);
192                    modelInfo.selfTransitions = new int[modelInfo.nStates];
193                    Arrays.fill(modelInfo.selfTransitions, -1);
194                    DenseBooleanMatrix2D transitions = fm.getLegalTransitions();
195                    if(transitions == null || allPaths) {
196                            transitions = new DenseBooleanMatrix2D(modelInfo.nStates, modelInfo.nStates);
197                            transitions.assign(true);
198                    }
199    
200                    short count = 0;
201                    for(short i = 0; i<modelInfo.nStates; ++i) {
202                            for(short j = 0; j<modelInfo.nStates; ++j) {
203                                    if(transitions.getQuick(i, j) || (i == j && isSemiMarkovState(i)))
204                                            count++;
205                            }
206                    }
207                    modelInfo.nTransitions = count;
208                    modelInfo.nPotentials = modelInfo.nStates + modelInfo.nTransitions;
209                    modelInfo.orderedPotentials = new short[modelInfo.nPotentials];
210                    modelInfo.transitionFrom = new short[modelInfo.nTransitions];
211                    modelInfo.transitionTo = new short[modelInfo.nTransitions];
212                    count = 0;
213                    int orderedCount = 0;
214                    for(short i = 0; i<modelInfo.nStates; ++i) {
215                            modelInfo.orderedPotentials[orderedCount] = i;
216                            orderedCount++;
217                            for(short j = 0; j<modelInfo.nStates; ++j) {
218                                    if(transitions.getQuick(j, i) || (i == j && isSemiMarkovState(i))) {
219                                            modelInfo.orderedPotentials[orderedCount] = (short) (modelInfo.nStates+count);
220                                            orderedCount++;
221                                            if(i == j)
222                                                    modelInfo.selfTransitions[i] = count;
223                                            modelInfo.transitionIndex.setQuick(j, i, count);
224                                            modelInfo.transitionFrom[count] = j;
225                                            modelInfo.transitionTo[count] = i;
226                                            ++count;
227                                    }
228                            }
229                    }
230                    Assert.a(count == modelInfo.nTransitions);
231            }
232    
233            public String getTrainingFile() {
234                    return trainingFile;
235            }
236    
237            public void setTrainingFile(String trainingFile) {
238                    this.trainingFile = trainingFile;
239            }
240    
241            /**
242             * @return the data
243             */
244            public List<? extends TrainingSequence<?>> getData() {
245                    return data;
246            }
247    }