Ymir  .9
Fast C++ tool for computation of assembling probabilities, statistical inference of assembling statistical model and generation of artificial sequences of T-cell receptors data.
sg_algorithm.h
1 //
2 // Created by Vadim N. on 11/03/2016.
3 //
4 
5 #ifndef YMIR_SG_ALGORITHM_H
6 #define YMIR_SG_ALGORITHM_H
7 
8 
9 #include "em_algorithm.h"
10 
11 
12 namespace ymir {
13 
14  class SGAlgorithm;
15 
16 
22  class SGAlgorithm : public EMAlgorithm {
23  public:
24 
25  virtual std::vector<prob_t> statisticalInference(const ClonesetView& repertoire,
27  const AlgorithmParameters& algo_param = AlgorithmParameters()
28  .set("niter", 10)
29  .set("block.size", 5000)
30  .set("alpha", .6)
31  .set("beta", 1.)
32  .set("K", 2.)
33  .set("prebuild", false)
34  .set("recompute.all", false)
35  .set("sample", 50000),
36  ErrorMode error_mode = NO_ERRORS) const
37  {
38  // shuffle input data at each step
39  // subvec -4008648.1
40 
41  cout << "Statistical inference on a PAM:\t" << model.name() << endl;
42  cout << "\tOnline EM-algorithm.";
43  if (error_mode == COMPUTE_ERRORS) {
44  cout << "\t(with sequence errors)";
45  }
46  std::cout << std::endl;
47 
48 
49  if (!algo_param.check("niter")
50  && !algo_param.check("block.size")
51  && !algo_param.check("alpha")
52  && !algo_param.check("beta")
53  && !algo_param.check("K"))
54  {
55  return std::vector<prob_t>();
56  }
57 
58 // bool prebuild = algo_param.get("prebuild", false).asBool();
59 // bool recompute_all = algo_param.get("recompute.all", false).asBool();
60 
61  std::cout << "\t -- #iterations: " << (size_t) algo_param["niter"].asUInt() << std::endl;
62  std::cout << "\t -- block size: " << (size_t) algo_param["block.size"].asUInt() << std::endl;
63  std::cout << "\t -- alpha: " << (double) algo_param["alpha"].asDouble() << std::endl;
64  std::cout << "\t -- beta: " << (double) algo_param["beta"].asDouble() << std::endl;
65 // std::cout << "\t -- gamma: " << (double) algo_param["gamma"].asDouble() << std::endl;
66  std::cout << "\t -- K: " << (double) algo_param["K"].asDouble() << std::endl;
67 // std::cout << "\t -- prebuild: " << (size_t) algo_param["prebuild"].asBool() << std::endl;
68 // std::cout << "\t -- recomp. logL:" << (size_t) algo_param["recompute.all"].asBool() << std::endl;
69 
70  std::vector<prob_t> logLvec;
71 
72 
73  size_t sample = algo_param["sample"].asUInt();
74  ClonesetView rep_nonc = repertoire.noncoding().sample(sample);
75 
76 // rep_nonc = rep_nonc.sample(algo_param.get("sample", (Json::Value::UInt64) rep_nonc.size()).asUInt64()); // TODO: CHECK IS IT OK TO ASSIGN TO ITSELF?
77  cout << "Number of noncoding clonotypes:\t" << (size_t) rep_nonc.size() << endl;
78 
79 
80  size_t start_i = rep_nonc.size();
81  size_t block_size = algo_param["block.size"].asUInt();
82  prob_t alpha = algo_param["alpha"].asDouble(); // step(k) = (k + 2)^(-alpha), .5 < alpha <= 1
83  prob_t beta = algo_param["beta"].asDouble();
84  prob_t Kparam = algo_param["K"].asDouble(); // step(k) = (k + 2)^(-alpha), .5 < alpha <= 1
85  ModelParameterVector new_param_vec = model.event_probabilities();
86  new_param_vec.fill(1);
87  new_param_vec.set_error_prob(.0003);
88  new_param_vec.normaliseEventFamilies();
89  model.updateModelParameterVector(new_param_vec);
90 
91  std::vector<bool> changed(new_param_vec.size(), false);
92 
93  auto maag_rep = model.buildGraphs(rep_nonc, SAVE_METADATA, error_mode, NUCLEOTIDE, true);
94  std::cout << "MAAG rep size" << (size_t) maag_rep.size() << std::endl;
95 
96  std::vector<prob_t> prob_vec(maag_rep.size(), 0);
97  prob_t prev_ll = 0;
98 
99  cout << "Computing full assembling probabilities..." << endl;
100  vector<bool> good_clonotypes;
101  size_t removed, zero_prob, no_alignments;
102  this->filterOut(rep_nonc, maag_rep, prob_vec, good_clonotypes, removed, zero_prob, no_alignments);
103 
104  std::vector<size_t> indices;
105  for (size_t i = 0; i < maag_rep.size(); ++i) {
106  if (good_clonotypes[i]) {
107  indices.push_back(i);
108  }
109  }
110  std::cout << "MAAGs in work:\t" << (size_t) indices.size() << endl;
111  std::random_shuffle(indices.begin(), indices.end());
112 
113  cout << endl << "Initial data summary:" << endl;
114  prob_summary(prob_vec);
115  std::cout << model.event_probabilities().error_prob() << std::endl;
116  prev_ll = loglikelihood(prob_vec);
117  logLvec.push_back(prev_ll);
118 
120  for (size_t iter = 1; iter <= algo_param["niter"].asUInt(); ++iter) {
121  if (start_i + block_size > indices.size()) {
122  start_i = 0;
123  std::random_shuffle(indices.begin(), indices.end());
124  } else {
125  start_i += block_size;
126  }
127 
128  std::cout << "=======================" << std::endl
129  << "Iteration: " << (size_t) iter << " Block: [" << (int) start_i << ":" << (int) std::min(indices.size() - 1, start_i + block_size - 1)
130  << "]" << std::endl << "=======================" << std::endl;
131 
132  new_param_vec.fill(0);
133 
134  for (size_t maag_i = start_i; maag_i < std::min(indices.size(), start_i + block_size); ++maag_i) {
135 // cout << "start of the iteration" << endl;
136  // compute marginal probabilities for this block
137  // and update the temporary model parameter vector
138 // std::cout << (size_t) maag_i << " / " << (size_t) (std::min(indices.size(), start_i + block_size)) << std::endl;
139 // std::cout << (size_t) indices[maag_i] << std::endl;
140 // std::cout << (size_t) indices[maag_i] << std::endl;
141 // std::cout << rep_nonc[indices[maag_i]].toString() << std::endl;
142 // std::cout << rep_nonc[indices[maag_i]].is_good() << std::endl;
143 // std::cout << good_clonotypes[indices[maag_i]] << std::endl;
144 // std::cout << prob_vec[indices[maag_i]] << std::endl;
145 // std::cout << maag_rep[indices[maag_i]].has_errors() << std::endl;
146 // std::cout << maag_rep[indices[maag_i]].has_events() << std::endl;
147  this->updateTempVec(fb, maag_rep[indices[maag_i]], new_param_vec, changed, error_mode);
148 // return vector<prob_t>();
149 // cout << "end of the iteration" << endl;
150  }
151 
152  std::cout << "Err:" << new_param_vec.error_prob() << std::endl;
153  this->updateModel(model, new_param_vec, maag_rep, prob_vec, prev_ll, changed, pow(beta*iter + Kparam, -alpha), error_mode);
154 
155  logLvec.push_back(prev_ll);
156  }
157 
158  return logLvec;
159  }
160 
161  protected:
162 
163  void updateTempVec(MAAGForwardBackwardAlgorithm &fb,
164  MAAG &maag,
165  ModelParameterVector &new_param_vec,
166  vector<bool> &changed,
167  ErrorMode error_mode) const
168  {
169  event_pair_t ep;
170  while (!fb.is_empty()) {
171  ep = fb.nextEvent();
172  new_param_vec[ep.first] += ep.second;
173  changed[ep.first] = true;
174  }
175 
176  if (error_mode) {
177  new_param_vec.set_error_prob(new_param_vec.error_prob() + fb.err_prob());
178  }
179 
180  if (maag.is_vj()) {
181  new_param_vec[new_param_vec.event_index(VJ_VAR_JOI_INS_NUC, 0, 0)] += fb.VJ_nuc_probs()[0];
182  new_param_vec[new_param_vec.event_index(VJ_VAR_JOI_INS_NUC, 0, 1)] += fb.VJ_nuc_probs()[1];
183  new_param_vec[new_param_vec.event_index(VJ_VAR_JOI_INS_NUC, 0, 2)] += fb.VJ_nuc_probs()[2];
184  new_param_vec[new_param_vec.event_index(VJ_VAR_JOI_INS_NUC, 0, 3)] += fb.VJ_nuc_probs()[3];
185 
186  changed[new_param_vec.event_index(VJ_VAR_JOI_INS_NUC, 0, 0)] = true;
187  changed[new_param_vec.event_index(VJ_VAR_JOI_INS_NUC, 0, 1)] = true;
188  changed[new_param_vec.event_index(VJ_VAR_JOI_INS_NUC, 0, 2)] = true;
189  changed[new_param_vec.event_index(VJ_VAR_JOI_INS_NUC, 0, 3)] = true;
190  } else {
191  int k_vd = new_param_vec.event_index(VDJ_VAR_DIV_INS_NUC, 0, 0),
192  k_dj = new_param_vec.event_index(VDJ_DIV_JOI_INS_NUC, 0, 0);
193 
194  // it's not working in Linux Fedora g++
195 // for (auto i = 0; i < 16; ++i) {
196 // std::cout << (int) i << std::endl;
197 // new_param_vec[i + k_vd] += fb.VD_nuc_probs()[i];
198 // changed[i + k_vd] = true;
199 //
200 // new_param_vec[i + k_dj] += fb.DJ_nuc_probs()[i];
201 // changed[i + k_dj] = true;
202 // }
203 
204  new_param_vec[k_vd] += fb.VD_nuc_probs()[0];
205  new_param_vec[k_vd + 1] += fb.VD_nuc_probs()[1];
206  new_param_vec[k_vd + 2] += fb.VD_nuc_probs()[2];
207  new_param_vec[k_vd + 3] += fb.VD_nuc_probs()[3];
208  new_param_vec[k_vd + 4] += fb.VD_nuc_probs()[4];
209  new_param_vec[k_vd + 5] += fb.VD_nuc_probs()[5];
210  new_param_vec[k_vd + 6] += fb.VD_nuc_probs()[6];
211  new_param_vec[k_vd + 7] += fb.VD_nuc_probs()[7];
212  new_param_vec[k_vd + 8] += fb.VD_nuc_probs()[8];
213  new_param_vec[k_vd + 9] += fb.VD_nuc_probs()[9];
214  new_param_vec[k_vd + 10] += fb.VD_nuc_probs()[10];
215  new_param_vec[k_vd + 11] += fb.VD_nuc_probs()[11];
216  new_param_vec[k_vd + 12] += fb.VD_nuc_probs()[12];
217  new_param_vec[k_vd + 13] += fb.VD_nuc_probs()[13];
218  new_param_vec[k_vd + 14] += fb.VD_nuc_probs()[14];
219  new_param_vec[k_vd + 15] += fb.VD_nuc_probs()[15];
220 
221  changed[k_vd] = true;
222  changed[k_vd + 1] = true;
223  changed[k_vd + 2] = true;
224  changed[k_vd + 3] = true;
225  changed[k_vd + 4] = true;
226  changed[k_vd + 5] = true;
227  changed[k_vd + 6] = true;
228  changed[k_vd + 7] = true;
229  changed[k_vd + 8] = true;
230  changed[k_vd + 9] = true;
231  changed[k_vd + 10] = true;
232  changed[k_vd + 11] = true;
233  changed[k_vd + 12] = true;
234  changed[k_vd + 13] = true;
235  changed[k_vd + 14] = true;
236  changed[k_vd + 15] = true;
237 
238  new_param_vec[k_dj] += fb.DJ_nuc_probs()[0];
239  new_param_vec[k_dj + 1] += fb.DJ_nuc_probs()[1];
240  new_param_vec[k_dj + 2] += fb.DJ_nuc_probs()[2];
241  new_param_vec[k_dj + 3] += fb.DJ_nuc_probs()[3];
242  new_param_vec[k_dj + 4] += fb.DJ_nuc_probs()[4];
243  new_param_vec[k_dj + 5] += fb.DJ_nuc_probs()[5];
244  new_param_vec[k_dj + 6] += fb.DJ_nuc_probs()[6];
245  new_param_vec[k_dj + 7] += fb.DJ_nuc_probs()[7];
246  new_param_vec[k_dj + 8] += fb.DJ_nuc_probs()[8];
247  new_param_vec[k_dj + 9] += fb.DJ_nuc_probs()[9];
248  new_param_vec[k_dj + 10] += fb.DJ_nuc_probs()[10];
249  new_param_vec[k_dj + 11] += fb.DJ_nuc_probs()[11];
250  new_param_vec[k_dj + 12] += fb.DJ_nuc_probs()[12];
251  new_param_vec[k_dj + 13] += fb.DJ_nuc_probs()[13];
252  new_param_vec[k_dj + 14] += fb.DJ_nuc_probs()[14];
253  new_param_vec[k_dj + 15] += fb.DJ_nuc_probs()[15];
254 
255  changed[k_dj] = true;
256  changed[k_dj + 1] = true;
257  changed[k_dj + 2] = true;
258  changed[k_dj + 3] = true;
259  changed[k_dj + 4] = true;
260  changed[k_dj + 5] = true;
261  changed[k_dj + 6] = true;
262  changed[k_dj + 7] = true;
263  changed[k_dj + 8] = true;
264  changed[k_dj + 9] = true;
265  changed[k_dj + 10] = true;
266  changed[k_dj + 11] = true;
267  changed[k_dj + 12] = true;
268  changed[k_dj + 13] = true;
269  changed[k_dj + 14] = true;
270  changed[k_dj + 15] = true;
271  }
272  }
273 
274 
275  void updateModel(ProbabilisticAssemblingModel &model,
276  ModelParameterVector &new_param_vec,
277  MAAGRepertoire &maag_rep,
278  vector<prob_t> &prob_vec,
279  prob_t &prev_ll,
280  vector<bool> &changed,
281  prob_t step_k,
282  ErrorMode error_mode) const
283  {
284 // if (error_mode) { new_param_vec.set_error_prob(new_param_vec.error_prob() / maag_rep.size()); }
285  if (error_mode) { new_param_vec.set_error_prob(.0003); }
286 
287  new_param_vec.normaliseEventFamilies();
288 
289  for (size_t i = 0; i < new_param_vec.size(); ++i) {
290  if (!changed[i]) {
291  new_param_vec[i] = model.event_probabilities()[i];
292  } else {
293  new_param_vec[i] = step_k * model.event_probabilities()[i] + (1 - step_k) * new_param_vec[i];
294  }
295  }
296 
297  new_param_vec.normaliseEventFamilies();
298 
299  model.updateModelParameterVector(new_param_vec);
300  model.updateEventProbabilities(&maag_rep, false);
301 
302  for (size_t i = 0; i < maag_rep.size(); ++i) {
303  prob_vec[i] = maag_rep[i].fullProbability();
304  }
305  prob_summary(prob_vec, prev_ll);
306  prev_ll = loglikelihood(prob_vec);
307 
308  changed.clear();
309  changed.resize(new_param_vec.size(), false);
310  }
311 
312  };
313 
314 }
315 
316 #endif //YMIR_SG_ALGORITHM_H
Definition: aligner.h:37
Definition: sg_algorithm.h:22
event_ind_t event_index(EventClass event_class, event_ind_t event_family, event_ind_t event_index) const
Definition: modelparametervector.h:261
Definition: maagforwardbackwardalgorithm.h:20
Multi-Alignment Assembly Graph - basic class for representing all possible generation scenarios of a ...
Definition: maag.h:46
const ModelParameterVector & event_probabilities() const
Access to vector of probabilities.
Definition: probabilisticassemblingmodel.h:185
void updateEventProbabilities(MAAGRepertoire *repertoire, bool verbose=true)
Update event probabilities in the given MAAG repertoire with new ones.
Definition: probabilisticassemblingmodel.h:143
Definition: statisticalinferencealgorithm.h:49
void normaliseEventFamilies()
Normalise each event family to have sum equal to 1.
Definition: modelparametervector.h:316
MAAGRepertoire buildGraphs(const ClonesetView &repertoire, MetadataMode save_metadata, ErrorMode error_mode, SequenceType sequence_type=NUCLEOTIDE, bool verbose=true) const
Build a set of MAAGs from the given cloneset.
Definition: probabilisticassemblingmodel.h:126
Class for storing parameters of assembling statistical model. Note: event with index 0 (zero) is "nul...
Definition: modelparametervector.h:68
Definition: probabilisticassemblingmodel.h:41
Definition: repertoire.h:51
void fill(prob_t val=0)
Fill the vector with the given value.
Definition: modelparametervector.h:377
void updateModelParameterVector(const ModelParameterVector &vec)
Set a new event probabilities vector to this model.
Definition: probabilisticassemblingmodel.h:199
Implementation of the EM-algorithm for statistical inference of assembling model parameters. Classic version described in (Murugan et al 2012)
Definition: em_algorithm.h:23