001 package calhoun.analysis.crf.solver;
002
003 import java.util.ArrayList;
004 import java.util.List;
005
006 import calhoun.analysis.crf.ModelManager;
007 import calhoun.analysis.crf.io.InputSequence;
008 import calhoun.analysis.crf.io.TrainingSequence;
009 import calhoun.util.Assert;
010 import calhoun.util.DenseIntMatrix2D;
011
/** Interface to implementations of feature caches. The <code>CacheProcessor</code> is the
 * interface between most of the objective functions and the actual feature managers.
 */
015 public interface CacheProcessor {
016
017 /** This class holds the feature evaluations for a given position, or position/length combination.
018 * There are two arrays, one to hold the feature indicies, and the other to hold the values. Each
019 * array is two-dimensional. The first dimension is fixed length and contains one entry for each potential
020 * (node or edge) in the model. The potentials are in model order, with each node followed by all transitions into that node.
021 * The second dimension is variable length and contains a information on feature evaluations for that potential.
022 * In the second dimension, each list of evaluations is terminated by a -1 in the index column.
023 */
024 public static class FeatureEvaluation {
025 FeatureEvaluation(int nFeatures) {
026 index = new short[nFeatures];
027 value = new float[nFeatures];
028 index[0] = -1;
029 }
030 public short[] index;
031 public float[] value;
032
033 public static FeatureEvaluation[] create(int nPotentials, int featureSize) {
034 FeatureEvaluation[] potential = new FeatureEvaluation[nPotentials];
035 for(int i=0; i<potential.length; ++i) {
036 potential[i] = new FeatureEvaluation(featureSize);
037 }
038 return potential;
039 }
040 }
041
042 public static class LengthFeatureEvaluation {
043 public LengthFeatureEvaluation(int nFeatures) {
044 nodeEval = new FeatureEvaluation(nFeatures);
045 }
046
047 public short lookback = -1;
048 public FeatureEvaluation nodeEval;
049 public FeatureEvaluation[] edgeEvals;
050
051 public static LengthFeatureEvaluation[][] create(StatePotentials[] statePotentials, int nLookbacks, int nFeatures) {
052 LengthFeatureEvaluation[][] lookbacks = new LengthFeatureEvaluation[statePotentials.length][nLookbacks];
053 for(int i=0; i<statePotentials.length; ++i) {
054 for(int j=0; j<nLookbacks; ++j) {
055 lookbacks[i][j] = new LengthFeatureEvaluation(nFeatures);
056 }
057 }
058 return lookbacks;
059 }
060 }
061
062 public static class SolverSetup {
063 public int nFeatures;
064 public int nStates;
065 public int nPotentials;
066 public int nTransitions;
067
068 /// Index into the starts array of the first position of each sequence.
069 public int[] seqOffsets;
070
071 /// Number of sequences in the training data set
072 public int nSeqs;
073
074 /// The length of the longest single sequence in the data set.
075 public int longestSeq;
076
077 public short[] transitionFrom;
078 public short[] transitionTo;
079 public short[] orderedPotentials;
080 public int totalPositions;
081 public int[] selfTransitions;
082
083 public DenseIntMatrix2D transitionIndex;
084 // For each array, contains a list of the possible destinate
085 public byte[][] exitTransitions;
086
087 public short[] maxStateLengths;
088
089 public short maxLookback = 1;
090 public StatePotentials[] statesWithLookback;
091 public StatePotentials[] statesWithoutLookback;
092
093 public void setup(ModelManager fm, List<? extends TrainingSequence<?>> data, boolean allPaths, short[] maxStateLengths2, boolean ignoreSemiMarkovSelfTransitions) {
094 this.maxStateLengths = maxStateLengths2;
095 initStatesWithLookback(maxStateLengths);
096
097 Assert.a(statesWithLookback != null);
098 Assert.a(statesWithoutLookback != null);
099 }
100
101 protected boolean allowSelf(int state) {
102 return maxStateLengths[state] > 1;
103 }
104
105
106 /** Creates an array of StatePotential objects given the maximum lookback for each state. */
107 public void initStatesWithLookback(short[] maxStateLengths) {
108 List<CacheProcessor.StatePotentials> with = new ArrayList<CacheProcessor.StatePotentials>();
109 List<CacheProcessor.StatePotentials> without = new ArrayList<CacheProcessor.StatePotentials>();
110 exitTransitions = new byte[maxStateLengths.length][];
111
112 for(byte i=0; i< maxStateLengths.length; ++i) {
113 // Set up the statePotentials structure
114 CacheProcessor.StatePotentials p = new CacheProcessor.StatePotentials();
115 p.state = i;
116 List<Byte> pots = new ArrayList<Byte>();
117 List<Byte> exits = new ArrayList<Byte>();
118 boolean length = maxStateLengths[i] > 1;
119 for(int prevState = 0; prevState < maxStateLengths.length; ++prevState) {
120 // Get list of transitions into this state
121 int trans = transitionIndex.getQuick(prevState, i);
122 if(trans != -1) {
123 pots.add((byte) (trans + nStates));
124 }
125
126 // Get list of transitions out this state
127 trans = transitionIndex.getQuick(i, prevState);
128 if(trans != -1 && prevState != i) {
129 exits.add((byte) (trans + nStates));
130 }
131 }
132 p.potentials = toByteArray(pots);
133 if(length) {
134 with.add(p);
135 }
136 else {
137 without.add(p);
138 }
139 exitTransitions[i] = toByteArray(exits);
140 }
141 statesWithLookback = with.toArray(new CacheProcessor.StatePotentials[with.size()]);
142 statesWithoutLookback = without.toArray(new CacheProcessor.StatePotentials[without.size()]);
143 }
144
145 private byte[] toByteArray(List<Byte> list) {
146 byte[] ret = new byte[list.size()];
147 for(int i=0; i<ret.length; ++i) {
148 ret[i] = list.get(i);
149 }
150 return ret;
151 }
152 }
153
154 public static final class StatePotentials {
155 public byte state;
156 public byte[] potentials;
157 }
158
159 public void setInputData(ModelManager fm, InputSequence<?> data);
160
161 public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data);
162 public List<? extends TrainingSequence<?>> getData();
163
164 public double[] getFeatureSums();
165 public double[][] getSequenceFeatureSums();
166 public boolean[] getInvalidTransitions();
167
168 public SolverSetup getSolverSetup();
169
170 public FeatureEvaluation[] getFeatureEvaluations();
171
172 public LengthFeatureEvaluation[][] getLengthFeatureEvaluations();
173
174 public void evaluatePosition(int seq, int pos);
175
176 public void evaluateSegmentsEndingAt(int seq, int pos);
177 }