001 package calhoun.analysis.crf.solver;
002
003 import java.util.List;
004
005 import calhoun.analysis.crf.FeatureList;
006 import calhoun.analysis.crf.ModelManager;
007 import calhoun.analysis.crf.io.InputSequence;
008 import calhoun.analysis.crf.io.TrainingSequence;
009 import calhoun.util.Assert;
010
011 /** a dummy cache processor that fulfills the interface but doesn't cache. It always retrieves values
012 * by calling the {@link calhoun.analysis.crf.FeatureManager} <code>evlaute</code> functions. <b>Does not work with
013 * NodeBoundary features.</b>
014 */
015 public class NoCachingCacheProcessor extends CacheProcessorBasic {
016 //private static final Log log = LogFactory.getLog(NoCachingCacheProcessor.class);
017
018 boolean allPaths;
019
020 /// Cached values of the sums of each feature value through the whole training set.
021 double[] featureSumsLocal;
022 boolean[] invalidTransitions;
023 DirectFeatureList result;
024
025 class DirectFeatureList implements FeatureList {
026 FeatureEvaluation evals1;
027 public int position;
028 boolean valid;
029
030 public DirectFeatureList() {
031 }
032
033 public void addFeature(int index, double val) {
034 evals1.index[position] = (short) index;
035 evals1.value[position++] = (float) val;
036 }
037
038 /** Returns the invalid flag. */
039 public boolean isValid() {
040 return valid;
041 }
042
043 /** Invalidates results. */
044 public void invalidate() {
045 valid = false;
046 }
047 }
048
049 /** true if all paths (valid and invalid) are to be evaluated during the viterbi search. Defaults to false.
050 * @return true if all paths are to be examined
051 */
052 public boolean isAllPaths() {
053 return allPaths;
054 }
055
056 /** sets whether all paths (valid and invalid) are to be evaluated during the viterbi search. Defaults to false.
057 * @param allPaths allPath true if all paths are to be examined
058 */
059 public void setAllPaths(boolean allPaths) {
060 this.allPaths = allPaths;
061 }
062
063 @Override
064 public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
065 super.setTrainingData(fm, data);
066 basicInit(allPaths);
067 result = new DirectFeatureList();
068
069 invalidTransitions = new boolean[modelInfo.nPotentials*modelInfo.totalPositions];
070 calcFeatureSums();
071 }
072
073 public boolean[] getInvalidTransitions() {
074 return invalidTransitions;
075 }
076
077 @Override
078 public double[] getFeatureSums() {
079 return featureSumsLocal;
080 }
081
082 void calcFeatureSums() {
083 featureSumsLocal = new double[modelInfo.nFeatures];
084 for(int seqNum = 0; seqNum < modelInfo.nSeqs; ++seqNum) {
085 TrainingSequence train = (TrainingSequence) data.get(seqNum);
086 int len = train.length();
087 int previousState = -1;
088 for(int pos = 0; pos < len; ++pos) {
089 evaluatePosition(seqNum, pos);
090 int state = train.getY(pos);
091 int i = 0;
092 FeatureEvaluation potEval = evals[state];
093 int index = potEval.index[i];
094 while(index != -1) {
095 float val = potEval.value[i];
096 Assert.a(!Float.isNaN(val));
097 featureSumsLocal[index] += val;
098 index = potEval.index[++i];
099 }
100
101 for(int trans = 0; trans < modelInfo.nTransitions; ++trans) {
102 if(modelInfo.transitionTo[trans] == state && modelInfo.transitionFrom[trans] == previousState) {
103 i = 0;
104 int pot = modelInfo.nStates+trans;
105 potEval = evals[pot];
106 index = potEval.index[i];
107 while(index != -1) {
108 float val = potEval.value[i];
109 Assert.a(!Float.isNaN(val));
110 featureSumsLocal[index] += val;
111 index = potEval.index[++i];
112 }
113 }
114 }
115 previousState = state;
116 }
117 }
118 }
119
120 public void evaluatePosition(int seqNum, int pos) {
121 InputSequence seq = data.get(seqNum);
122 for(int pot=0; pot<modelInfo.nPotentials; ++pot) {
123 result.evals1 = evals[pot];
124 result.position = 0;
125 result.valid = true;
126 if(pot < modelInfo.nStates) {
127 fm.evaluateNode(seq, pos, pot, result);
128 }
129 else {
130 if(pos == 0) {
131 result.evals1.index[0] = -1;
132 continue;
133 }
134 int trans = pot - modelInfo.nStates;
135 fm.evaluateEdge(seq, pos, modelInfo.transitionFrom[trans], modelInfo.transitionTo[trans], result);
136 }
137 if(result.isValid()) {
138 result.evals1.index[result.position] = -1;
139 }
140 else {
141 result.evals1.value[0] = Float.NaN;
142 result.evals1.index[0] = Short.MIN_VALUE;
143 result.evals1.index[1] = -1;
144 }
145 }
146 }
147
148 public void evaluateSegmentsEndingAt(int seq, int pos) {
149 throw new UnsupportedOperationException();
150 }
151
152 }
153