001    package calhoun.analysis.crf.solver;
002    
003    import java.util.ArrayList;
004    import java.util.List;
005    
006    import calhoun.analysis.crf.ModelManager;
007    import calhoun.analysis.crf.io.InputSequence;
008    import calhoun.analysis.crf.io.TrainingSequence;
009    import calhoun.util.Assert;
010    import calhoun.util.DenseIntMatrix2D;
011    
012    /** interface to implementations of feature caches.  The <code>CacheProcessor</code> is the 
013     * interface between most of the objective functions and the actual feature managers.
014     */  
015    public interface CacheProcessor {
016            
017            /** This class holds the feature evaluations for a given position, or position/length combination.
018             * There are two arrays, one to hold the feature indicies, and the other to hold the values.  Each
019             * array is two-dimensional.  The first dimension is fixed length and contains one entry for each potential 
020             * (node or edge) in the model.  The potentials are in model order, with each node followed by all transitions into that node.
021             * The second dimension is variable length and contains a information on feature evaluations for that potential.
022             * In the second dimension, each list of evaluations is terminated by a -1 in the index column.  
023             */ 
024            public static class FeatureEvaluation {
025                    FeatureEvaluation(int nFeatures) {
026                            index = new short[nFeatures];
027                            value = new float[nFeatures];
028                            index[0] = -1;
029                    }
030                    public short[] index;
031                    public float[] value;
032    
033                    public static FeatureEvaluation[] create(int nPotentials, int featureSize) {
034                            FeatureEvaluation[] potential = new FeatureEvaluation[nPotentials];
035                            for(int i=0; i<potential.length; ++i) {
036                                    potential[i] = new FeatureEvaluation(featureSize);
037                            }
038                            return potential;
039                    }
040            }
041            
042            public static class LengthFeatureEvaluation {
043                    public LengthFeatureEvaluation(int nFeatures) {
044                            nodeEval = new FeatureEvaluation(nFeatures);
045                    }
046                    
047                    public short lookback = -1;
048                    public FeatureEvaluation nodeEval;
049                    public FeatureEvaluation[] edgeEvals;
050                    
051                    public static LengthFeatureEvaluation[][] create(StatePotentials[] statePotentials, int nLookbacks, int nFeatures) {
052                            LengthFeatureEvaluation[][] lookbacks = new LengthFeatureEvaluation[statePotentials.length][nLookbacks];
053                            for(int i=0; i<statePotentials.length; ++i) {
054                                    for(int j=0; j<nLookbacks; ++j) {
055                                            lookbacks[i][j] = new LengthFeatureEvaluation(nFeatures);
056                                    }
057                            }
058                            return lookbacks;
059                    }
060            }
061            
062            public static class SolverSetup {
063                    public int nFeatures;
064                    public int nStates;
065                    public int nPotentials;
066                    public int nTransitions;
067                    
068                    /// Index into the starts array of the first position of each sequence.
069                    public int[] seqOffsets;
070    
071                    /// Number of sequences in the training data set
072                    public int nSeqs; 
073    
074                    /// The length of the longest single sequence in the data set.
075                    public int longestSeq; 
076    
077                    public short[] transitionFrom;
078                    public short[] transitionTo;
079                    public short[] orderedPotentials;
080                    public int totalPositions;
081                    public int[] selfTransitions;
082    
083                    public DenseIntMatrix2D transitionIndex;
084                    // For each array, contains a list of the possible destinate
085                    public byte[][] exitTransitions;
086                    
087                    public short[] maxStateLengths;
088    
089                    public short maxLookback = 1;
090                    public StatePotentials[] statesWithLookback;
091                    public StatePotentials[] statesWithoutLookback;
092    
093                    public void setup(ModelManager fm, List<? extends TrainingSequence<?>> data, boolean allPaths, short[] maxStateLengths2, boolean ignoreSemiMarkovSelfTransitions) {
094                            this.maxStateLengths = maxStateLengths2;
095                            initStatesWithLookback(maxStateLengths);
096                            
097                            Assert.a(statesWithLookback != null);
098                            Assert.a(statesWithoutLookback != null);                        
099                    }
100                    
101                    protected boolean allowSelf(int state) {
102                            return maxStateLengths[state] > 1;
103                    }
104                    
105    
106                    /** Creates an array of StatePotential objects given the maximum lookback for each state. */
107                    public void initStatesWithLookback(short[] maxStateLengths) {
108                            List<CacheProcessor.StatePotentials> with = new ArrayList<CacheProcessor.StatePotentials>();
109                            List<CacheProcessor.StatePotentials> without = new ArrayList<CacheProcessor.StatePotentials>();
110                            exitTransitions = new byte[maxStateLengths.length][];
111    
112                            for(byte i=0; i< maxStateLengths.length; ++i) {
113                                    // Set up the statePotentials structure
114                                    CacheProcessor.StatePotentials p = new CacheProcessor.StatePotentials();
115                                    p.state = i;
116                                    List<Byte> pots = new ArrayList<Byte>();
117                                    List<Byte> exits = new ArrayList<Byte>();
118                                    boolean length = maxStateLengths[i] > 1;
119                                    for(int prevState = 0; prevState < maxStateLengths.length; ++prevState) {
120                                            // Get list of transitions into this state
121                                            int trans = transitionIndex.getQuick(prevState, i);
122                                            if(trans != -1) {
123                                                    pots.add((byte) (trans + nStates));
124                                            }
125    
126                                            // Get list of transitions out  this state
127                                            trans = transitionIndex.getQuick(i, prevState);
128                                            if(trans != -1 && prevState != i) {
129                                                    exits.add((byte) (trans + nStates));
130                                            }
131                                    }
132                                    p.potentials = toByteArray(pots);
133                                    if(length) {
134                                            with.add(p);
135                                    }
136                                    else {
137                                            without.add(p);
138                                    }
139                                    exitTransitions[i] = toByteArray(exits);
140                            }
141                            statesWithLookback = with.toArray(new CacheProcessor.StatePotentials[with.size()]);
142                            statesWithoutLookback = without.toArray(new CacheProcessor.StatePotentials[without.size()]);
143                    }
144    
145                    private byte[] toByteArray(List<Byte> list) {
146                            byte[] ret = new byte[list.size()];
147                            for(int i=0; i<ret.length; ++i) {
148                                    ret[i] = list.get(i);
149                            }
150                            return ret;
151                    }
152            }
153    
154            public static final class StatePotentials {
155                    public byte state;
156                    public byte[] potentials;
157            }
158    
159            public void setInputData(ModelManager fm, InputSequence<?> data);
160            
161            public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data);
162            public List<? extends TrainingSequence<?>> getData();
163            
164            public double[] getFeatureSums();
165            public double[][] getSequenceFeatureSums();
166            public boolean[] getInvalidTransitions();
167            
168            public SolverSetup getSolverSetup();
169            
170            public FeatureEvaluation[] getFeatureEvaluations();
171            
172            public LengthFeatureEvaluation[][] getLengthFeatureEvaluations();
173            
174            public void evaluatePosition(int seq, int pos);
175    
176            public void evaluateSegmentsEndingAt(int seq, int pos);
177    }