001    package calhoun.analysis.crf.features.tricycle13;
002    
003    import java.util.ArrayList;
004    import java.util.Arrays;
005    import java.util.Collections;
006    import java.util.HashMap;
007    import java.util.List;
008    import java.util.Map;
009    
010    import org.apache.commons.logging.Log;
011    import org.apache.commons.logging.LogFactory;
012    
013    import calhoun.analysis.crf.AbstractFeatureManager;
014    import calhoun.analysis.crf.FeatureList;
015    import calhoun.analysis.crf.FeatureManagerNode;
016    import calhoun.analysis.crf.ModelManager;
017    import calhoun.analysis.crf.io.InputSequence;
018    import calhoun.analysis.crf.io.TrainingSequence;
019    import calhoun.util.Assert;
020    
021    /** trains on the data and then evaluates to P(state | label) for given Kmers.  Used for historical reasons.  Emission markov generally does better.  
022     */
023    public class KmerFeatures extends AbstractFeatureManager<Character> implements FeatureManagerNode<Character> {
024            private static final long serialVersionUID = 5959560033335736926L;
025            private static final Log log = LogFactory.getLog(KmerFeatures.class);
026    
027            public enum Cardinality {
028                    SINGLE, PER_STATE, PER_KMER
029            };
030    
031            static final int DEFAULT_RARE_THRESHOLD = 25;
032            int startIx;
033            ModelManager model;
034            int nStates;
035            List<int[]> kmerDefs;
036            Map<String, double[]>[] counts;
037            Map<String, Integer>[] kmerIds;
038            Cardinality cardinality = Cardinality.PER_STATE;
039            int rareThreshold = DEFAULT_RARE_THRESHOLD;
040    
041            public KmerFeatures(Cardinality cardinality) {
042                    this.cardinality = cardinality;
043                    this.kmerDefs = Collections.singletonList(new int[] { 0 });
044            }
045    
046            public KmerFeatures() {
047                    this.kmerDefs = Collections.singletonList(new int[] { 0 });
048            }
049    
050            public KmerFeatures(List<int[]> kmerDefs) {
051                    this.kmerDefs = kmerDefs;
052            }
053    
054            public KmerFeatures(List<int[]> kmerDefs, Cardinality cardinality) {
055                    this.cardinality = cardinality;
056                    this.kmerDefs = kmerDefs;
057            }
058    
059            public void setKmerDefinitions(List<List<Integer>> defs) {
060                    kmerDefs = new ArrayList<int[]>();
061                    for(List<Integer> def : defs) {
062                            int[] kmer = new int[def.size()];
063                            kmerDefs.add(kmer);
064                            for(int i=0; i<def.size(); ++i) {
065                                    kmer[i] = def.get(i);
066                            }
067                    }
068            }
069            
070            public void setRareThreshold(int threshold) {
071                    Assert.a(model == null, "Can't set threshold after training.");
072                    rareThreshold = threshold;
073            }
074    
075            public int getNumFeatures() {
076                    switch (cardinality) {
077                    case SINGLE:
078                            return kmerDefs.size();
079                    case PER_STATE:
080                            return kmerDefs.size() * model.getNumStates();
081                    default:
082                            Assert.a(cardinality == Cardinality.PER_KMER);
083                            int num = 0;
084                            for (Map m : counts) {
085                                    num += m.size();
086                            }
087                            return num;
088                    }
089            }
090    
091            public String getFeatureName(int featureIndex) {
092                    int raw = featureIndex - startIx;
093                    int label = raw / kmerDefs.size();
094                    int kmer = raw % kmerDefs.size();
095                    String name = kmerName(kmer);
096                    String val = "";
097                    if (cardinality == Cardinality.PER_STATE) {
098                            val = model.getStateName(label) + ".";
099                    } else if (cardinality == Cardinality.PER_KMER) {
100                            val = "(InsertKMer)";
101                    }
102                    return "Kmer." + val + name;
103            }
104    
105            transient InputSequence<? extends Character> lastSeq;
106            transient int lastPos;
107            String[] kmers;
108            double[][] vals;
109    
110            public void evaluateNode(InputSequence<? extends Character> seq, int pos, int state, FeatureList result) {
111                    if (seq != lastSeq || pos != lastPos) {
112                            lastSeq = seq;
113                            lastPos = pos;
114                            for (int j = 0; j < kmerDefs.size(); ++j) {
115                                    String kmer = getKmer(seq, pos, kmerDefs.get(j));
116                                    vals[j] = counts[j].get(kmer);
117                                    if (cardinality == Cardinality.PER_KMER) {
118                                            kmers[j] = kmer;
119                                    }
120                            }
121                    }
122                    for (int j = 0; j < kmerDefs.size(); ++j) {
123                            double[] kmerVals = vals[j];
124                            if (kmerVals == null || kmerVals[state] == 0.0)
125                                    continue;
126                            int index;
127                            if (cardinality == Cardinality.SINGLE) {
128                                    index = startIx + j;
129                            } else if (cardinality == Cardinality.PER_STATE) {
130                                    index = startIx + j + kmerDefs.size() * state;
131                            } else {
132                                    Assert.a(cardinality == Cardinality.PER_KMER);
133                                    Integer id = kmerIds[j].get(kmers[j]);
134                                    if (id == null)
135                                            continue;
136                                    index = id;
137                            }
138                            result.addFeature(index, kmerVals[state]);
139                    }
140            }
141    
142            public String getKmer(InputSequence<? extends Character> seq, int pos, int[] def) {
143                    StringBuffer buf = new StringBuffer(def.length);
144                    for (int i = 0; i < def.length; ++i) {
145                            int loc = pos + def[i];
146                            if (loc < 0 || loc >= seq.length())
147                                    return null;
148                            buf.append((Character) seq.getX(loc));
149                    }
150                    return buf.toString();
151            }
152    
153            /** Returns a string representation of a given kmer definition */
154            public String kmerName(int index) {
155                    int[] kmer = kmerDefs.get(index);
156                    StringBuffer ret = new StringBuffer();
157                    for (int i = 0; i < kmer.length; ++i) {
158                            if (i != 0) {
159                                    ret.append(".");
160                            }
161                            ret.append(kmer[i]);
162                    }
163                    return ret.toString();
164            }
165    
166            /** Returns an individual entry from the counts list. */
167            public double getKmerProb(int kmerIndex, String kmer, int label) {
168                    return counts[kmerIndex].get(kmer)[label];
169            }
170    
171            /** Computes the P(label | kmer) for each kmer across all of the training data. These will used as features values. */
172            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
173                    startIx = startingIndex;
174                    model = modelInfo;
175                    nStates = model.getNumStates();
176                    vals = new double[kmerDefs.size()][];
177                    // Initialize an array of hash maps for each kmer
178                    counts = new HashMap[kmerDefs.size()];
179                    for (int i = 0; i < counts.length; ++i) {
180                            counts[i] = new HashMap();
181                    }
182                    if (cardinality == Cardinality.PER_KMER) {
183                            kmerIds = new HashMap[kmerDefs.size()];
184                            for (int i = 0; i < kmerIds.length; ++i) {
185                                    kmerIds[i] = new HashMap();
186                            }
187                    }
188                    // Count the occurances of each kmer in each state
189                    for (TrainingSequence<? extends Character> seq : data) {
190                            int len = seq.length();
191                            for (int pos = 0; pos < len; ++pos) {
192                                    int nKmers = kmerDefs.size();
193                                    for (int i = 0; i < nKmers; ++i) {
194                                            String kmer = getKmer(seq, pos, kmerDefs.get(i));
195                                            if (kmer != null) {
196                                                    double[] val = (double[]) counts[i].get(kmer);
197                                                    if (val == null) {
198                                                            val = new double[nStates];
199                                                            Arrays.fill(val, 1);
200                                                            counts[i].put(kmer, val);
201                                                    }
202                                                    int state = seq.getY(pos);
203                                                    val[state] += 1;
204                                            }
205                                    }
206                            }
207                    }
208                    int kmerId = startIx;
209                    // Now compute probabilities for all kmers with at least KMER_THRESHOLD appearances
210                    for (int j = 0; j < kmerDefs.size(); ++j) {
211                            int kmerCounts = 0;
212                            for (Map.Entry<String, double[]> val : counts[j].entrySet()) {
213                                    double[] cnts = val.getValue();
214                                    double total = 0.0f;
215                                    for (int i = 0; i < cnts.length; ++i) {
216                                            total += cnts[i];
217                                    }
218                                    if (total >= rareThreshold) {
219                                            kmerCounts += 1;
220                                            if (cardinality == Cardinality.PER_KMER) {
221                                                    kmerIds[j].put(val.getKey(), kmerId);
222                                                    kmerId++;
223                                            }
224                                    }
225                                    for (int i = 0; i < cnts.length; ++i) {
226                                            // Just convert to P(label | kmer)
227                                            if (total >= rareThreshold) {
228                                                    cnts[i] = Math.log(cnts[i] / total);
229                                                    Assert.a(!Double.isNaN(cnts[i]));
230                                                    Assert.a(!Double.isInfinite(cnts[i]));
231                                            } else {
232                                                    cnts[i] = 0;
233                                            }
234                                    }
235                            }
236                            log.info(kmerCounts + " kmers of " + kmerName(j));
237                    }
238            }
239    }