001 package calhoun.analysis.crf.features.tricycle13;
002
003 import java.util.ArrayList;
004 import java.util.Arrays;
005 import java.util.Collections;
006 import java.util.HashMap;
007 import java.util.List;
008 import java.util.Map;
009
010 import org.apache.commons.logging.Log;
011 import org.apache.commons.logging.LogFactory;
012
013 import calhoun.analysis.crf.AbstractFeatureManager;
014 import calhoun.analysis.crf.FeatureList;
015 import calhoun.analysis.crf.FeatureManagerNode;
016 import calhoun.analysis.crf.ModelManager;
017 import calhoun.analysis.crf.io.InputSequence;
018 import calhoun.analysis.crf.io.TrainingSequence;
019 import calhoun.util.Assert;
020
/** Trains on the data and then evaluates to log P(state | kmer) for the kmers found at each position. Retained for historical reasons; emission Markov features generally do better.
 */
023 public class KmerFeatures extends AbstractFeatureManager<Character> implements FeatureManagerNode<Character> {
/** Serialization version identifier. */
private static final long serialVersionUID = 5959560033335736926L;
/** Logger; used at the end of training to report how many kmers of each definition were kept. */
private static final Log log = LogFactory.getLog(KmerFeatures.class);
026
/** Selects how feature indices are laid out for the configured kmer definitions. */
public enum Cardinality {
    /** One feature per kmer definition, shared by all states. */
    SINGLE,
    /** One feature per (kmer definition, state) pair. */
    PER_STATE,
    /** One feature per distinct kmer string; indices are assigned during training to kmers that meet the rare threshold. */
    PER_KMER
}
030
/** Default minimum total count (smoothing pseudo-counts included) a kmer needs in order to receive non-zero feature values. */
static final int DEFAULT_RARE_THRESHOLD = 25;
/** Index of this manager's first feature, as passed to train(). */
int startIx;
/** Model passed to train(); null until training has run. */
ModelManager model;
/** Number of states in the model, cached by train(). */
int nStates;
/** Kmer definitions; each int[] lists the offsets, relative to the current position, of the characters that make up the kmer. */
List<int[]> kmerDefs;
/** Per definition: kmer string -> per-state values. Holds smoothed counts during training and afterwards log P(state | kmer), or 0 for rare kmers. */
Map<String, double[]>[] counts;
/** Per definition: kmer string -> feature index. Only created by train() when the cardinality is PER_KMER. */
Map<String, Integer>[] kmerIds;
/** Feature index layout; PER_STATE unless a constructor overrides it. */
Cardinality cardinality = Cardinality.PER_STATE;
/** Kmers whose total count is below this value get a value of 0 for every state. */
int rareThreshold = DEFAULT_RARE_THRESHOLD;
040
/**
 * Creates a manager with a single kmer definition consisting of the character at the
 * current position (offset 0), using the given cardinality.
 */
public KmerFeatures(Cardinality cardinality) {
    this(Collections.singletonList(new int[] { 0 }), cardinality);
}

/**
 * Creates a manager with a single kmer definition consisting of the character at the
 * current position (offset 0) and the default cardinality.
 */
public KmerFeatures() {
    this(Collections.singletonList(new int[] { 0 }));
}

/** Creates a manager for the given kmer definitions with the default cardinality. */
public KmerFeatures(List<int[]> kmerDefs) {
    this.kmerDefs = kmerDefs;
}

/** Creates a manager for the given kmer definitions and cardinality. */
public KmerFeatures(List<int[]> kmerDefs, Cardinality cardinality) {
    this.kmerDefs = kmerDefs;
    this.cardinality = cardinality;
}
058
059 public void setKmerDefinitions(List<List<Integer>> defs) {
060 kmerDefs = new ArrayList<int[]>();
061 for(List<Integer> def : defs) {
062 int[] kmer = new int[def.size()];
063 kmerDefs.add(kmer);
064 for(int i=0; i<def.size(); ++i) {
065 kmer[i] = def.get(i);
066 }
067 }
068 }
069
070 public void setRareThreshold(int threshold) {
071 Assert.a(model == null, "Can't set threshold after training.");
072 rareThreshold = threshold;
073 }
074
075 public int getNumFeatures() {
076 switch (cardinality) {
077 case SINGLE:
078 return kmerDefs.size();
079 case PER_STATE:
080 return kmerDefs.size() * model.getNumStates();
081 default:
082 Assert.a(cardinality == Cardinality.PER_KMER);
083 int num = 0;
084 for (Map m : counts) {
085 num += m.size();
086 }
087 return num;
088 }
089 }
090
091 public String getFeatureName(int featureIndex) {
092 int raw = featureIndex - startIx;
093 int label = raw / kmerDefs.size();
094 int kmer = raw % kmerDefs.size();
095 String name = kmerName(kmer);
096 String val = "";
097 if (cardinality == Cardinality.PER_STATE) {
098 val = model.getStateName(label) + ".";
099 } else if (cardinality == Cardinality.PER_KMER) {
100 val = "(InsertKMer)";
101 }
102 return "Kmer." + val + name;
103 }
104
/** Sequence whose kmer lookups are currently cached by evaluateNode; transient, so not serialized. */
transient InputSequence<? extends Character> lastSeq;
/** Position within lastSeq whose kmer lookups are currently cached; transient, so not serialized. */
transient int lastPos;
/** Kmer string found for each definition at the cached position; written by evaluateNode only for PER_KMER cardinality. */
String[] kmers;
/** Per-state values looked up for each definition at the cached position; an entry is null when the count table has no entry for the kmer. Allocated by train(). */
double[][] vals;
109
110 public void evaluateNode(InputSequence<? extends Character> seq, int pos, int state, FeatureList result) {
111 if (seq != lastSeq || pos != lastPos) {
112 lastSeq = seq;
113 lastPos = pos;
114 for (int j = 0; j < kmerDefs.size(); ++j) {
115 String kmer = getKmer(seq, pos, kmerDefs.get(j));
116 vals[j] = counts[j].get(kmer);
117 if (cardinality == Cardinality.PER_KMER) {
118 kmers[j] = kmer;
119 }
120 }
121 }
122 for (int j = 0; j < kmerDefs.size(); ++j) {
123 double[] kmerVals = vals[j];
124 if (kmerVals == null || kmerVals[state] == 0.0)
125 continue;
126 int index;
127 if (cardinality == Cardinality.SINGLE) {
128 index = startIx + j;
129 } else if (cardinality == Cardinality.PER_STATE) {
130 index = startIx + j + kmerDefs.size() * state;
131 } else {
132 Assert.a(cardinality == Cardinality.PER_KMER);
133 Integer id = kmerIds[j].get(kmers[j]);
134 if (id == null)
135 continue;
136 index = id;
137 }
138 result.addFeature(index, kmerVals[state]);
139 }
140 }
141
142 public String getKmer(InputSequence<? extends Character> seq, int pos, int[] def) {
143 StringBuffer buf = new StringBuffer(def.length);
144 for (int i = 0; i < def.length; ++i) {
145 int loc = pos + def[i];
146 if (loc < 0 || loc >= seq.length())
147 return null;
148 buf.append((Character) seq.getX(loc));
149 }
150 return buf.toString();
151 }
152
153 /** Returns a string representation of a given kmer definition */
154 public String kmerName(int index) {
155 int[] kmer = kmerDefs.get(index);
156 StringBuffer ret = new StringBuffer();
157 for (int i = 0; i < kmer.length; ++i) {
158 if (i != 0) {
159 ret.append(".");
160 }
161 ret.append(kmer[i]);
162 }
163 return ret.toString();
164 }
165
166 /** Returns an individual entry from the counts list. */
167 public double getKmerProb(int kmerIndex, String kmer, int label) {
168 return counts[kmerIndex].get(kmer)[label];
169 }
170
171 /** Computes the P(label | kmer) for each kmer across all of the training data. These will used as features values. */
172 public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
173 startIx = startingIndex;
174 model = modelInfo;
175 nStates = model.getNumStates();
176 vals = new double[kmerDefs.size()][];
177 // Initialize an array of hash maps for each kmer
178 counts = new HashMap[kmerDefs.size()];
179 for (int i = 0; i < counts.length; ++i) {
180 counts[i] = new HashMap();
181 }
182 if (cardinality == Cardinality.PER_KMER) {
183 kmerIds = new HashMap[kmerDefs.size()];
184 for (int i = 0; i < kmerIds.length; ++i) {
185 kmerIds[i] = new HashMap();
186 }
187 }
188 // Count the occurances of each kmer in each state
189 for (TrainingSequence<? extends Character> seq : data) {
190 int len = seq.length();
191 for (int pos = 0; pos < len; ++pos) {
192 int nKmers = kmerDefs.size();
193 for (int i = 0; i < nKmers; ++i) {
194 String kmer = getKmer(seq, pos, kmerDefs.get(i));
195 if (kmer != null) {
196 double[] val = (double[]) counts[i].get(kmer);
197 if (val == null) {
198 val = new double[nStates];
199 Arrays.fill(val, 1);
200 counts[i].put(kmer, val);
201 }
202 int state = seq.getY(pos);
203 val[state] += 1;
204 }
205 }
206 }
207 }
208 int kmerId = startIx;
209 // Now compute probabilities for all kmers with at least KMER_THRESHOLD appearances
210 for (int j = 0; j < kmerDefs.size(); ++j) {
211 int kmerCounts = 0;
212 for (Map.Entry<String, double[]> val : counts[j].entrySet()) {
213 double[] cnts = val.getValue();
214 double total = 0.0f;
215 for (int i = 0; i < cnts.length; ++i) {
216 total += cnts[i];
217 }
218 if (total >= rareThreshold) {
219 kmerCounts += 1;
220 if (cardinality == Cardinality.PER_KMER) {
221 kmerIds[j].put(val.getKey(), kmerId);
222 kmerId++;
223 }
224 }
225 for (int i = 0; i < cnts.length; ++i) {
226 // Just convert to P(label | kmer)
227 if (total >= rareThreshold) {
228 cnts[i] = Math.log(cnts[i] / total);
229 Assert.a(!Double.isNaN(cnts[i]));
230 Assert.a(!Double.isInfinite(cnts[i]));
231 } else {
232 cnts[i] = 0;
233 }
234 }
235 }
236 log.info(kmerCounts + " kmers of " + kmerName(j));
237 }
238 }
239 }