Ymir  .9
Fast C++ tool for computation of assembling probabilities, statistical inference of assembling statistical models and generation of artificial sequences of T-cell receptor data.
maagforwardbackwardalgorithm.h
1 //
2 // Created by Vadim N. on 20/04/2015.
3 //
4 
5 #ifndef YMIR_MAAGFORWARDBACKWARDALGORITHM_H
6 #define YMIR_MAAGFORWARDBACKWARDALGORITHM_H
7 
8 #include <unordered_map>
9 
10 #include "maag.h"
11 
12 namespace ymir {
13 
14  class MAAGForwardBackwardAlgorithm;
15  class ForwardBackwardAlgorithm;
16  class VJForwardBackward;
17  class VDJForwardBackward;
18 
19 
21 
22  protected:
23 
24  typedef ProbMMC::dim_t dim_t;
25  typedef ProbMMC::matrix_ind_t matrix_ind_t;
26  typedef ProbMMC::node_ind_t node_ind_t;
27 
28  public:
29 
30 
32  : _pairs_i(0),
33  _status(false),
34  _vectorised(false),
35  _err_prob(0)
36  {
37  }
38 
39 
40  MAAGForwardBackwardAlgorithm(const MAAG &maag, ErrorMode error_mode = NO_ERRORS) {
41  process(maag, error_mode);
42  }
43 
44 
46  {
47  }
48 
49 
53  bool process(const MAAG &maag, ErrorMode error_mode = NO_ERRORS) {
54  _pairs_i = 0;
55  _status = false;
56  _vectorised = false;
57  _pairs.clear();
58  _pairs.reserve(maag._chain.size() + 10);
59  _full_prob = 0;
60  _back_full_prob = 0;
61  _err_prob = 0;
62  _err_mode = error_mode;
63 
64  if (_err_mode == COMPUTE_ERRORS && !maag.has_errors()) {
65  cerr << "MAAG forward-backward algorithm error: no error matrix has been found in the input MAAG." << endl;
66  _status = false;
67  return false;
68  } else {
69  if (maag.has_events()) {
70  _status = true;
71  if (maag.recombination() == VJ_RECOMB) {
72  fill(_nuc_arr1, _nuc_arr1 + 4, 0);
73  this->forward_backward_vj(maag);
74  } else if (maag.recombination() == VDJ_RECOMB) {
75  fill(_nuc_arr1, _nuc_arr1 + 16, 0);
76  fill(_nuc_arr2, _nuc_arr2 + 16, 0);
77  this->forward_backward_vdj(maag);
78  } else {
79  cerr << "MAAG forward-backward algorithm error: unknown recombination type." << endl;
80  _status = false;
81  }
82  } else {
83  cerr << "MAAG forward-backward algorithm error: no event matrix has been found in the input MAAG." << endl;
84  _status = false;
85  }
86  }
87 
88  return _status;
89  }
90 
91 
92  event_pair_t nextEvent() {
93  if (_status) {
94  event_pair_t res = _pairs[_pairs_i];
95 // cout << res.second << " -> ";
96 // res.second = res.second / _full_prob;
97 // cout << res.second << endl;
98  ++_pairs_i;
99  return res;
100  }
101 
102  _status = false;
103  return event_pair_t(0, 0);
104  }
105 
106 
107  bool is_empty() {
108  if (_status) {
109  if (_pairs_i != _pairs.size()) {
110  return false;
111  }
112  }
113  return true;
114  }
115 
116 
117  bool status() const { return _status; }
118 
119 
120  prob_t fullProbability() const { return _full_prob; }
121 
122 
123  prob_t bfullProbability() const { return _back_full_prob; }
124 
125 
126  const vector<event_pair_t>& event_pairs() const { return _pairs; }
127 
128 
129  prob_t* VJ_nuc_probs() { return _nuc_arr1; }
130  prob_t* VD_nuc_probs() { return _nuc_arr1; }
131  prob_t* DJ_nuc_probs() { return _nuc_arr2; }
132 
133  prob_t err_prob() { return _err_prob; }
134 
135  protected:
136 
137  bool _status;
138  bool _vectorised;
139  pProbMMC _forward_acc, _backward_acc;
140  prob_t _full_prob;
143  vector<event_pair_t> _pairs;
144  size_t _pairs_i;
145  unordered_map<event_ind_t, prob_t> _pair_map;
146  prob_t _nuc_arr1[16], _nuc_arr2[16];
147  prob_t _err_prob;
148  ErrorMode _err_mode;
149 
150 
154  void fillZero(ProbMMC *mmc, uint start_node = 0) {
155  for (uint i = start_node; i < mmc->chainSize(); ++i) {
156  for (uint j = 0; j < mmc->nodeSize(i); ++j) {
157  mmc->fill(i, j, 0);
158  }
159  }
160  }
161 
162 
166  void inferInsertionNucleotides(const MAAG &maag, node_ind_t ins_node,
167  seq_len_t left_start_pos, seq_len_t left_end_pos,
168  seq_len_t right_start_pos, seq_len_t right_end_pos,
169  prob_t *nuc_arr, bool reversed = false)
170  {
171  node_ind_t forw_node = ins_node - 1, back_node = ins_node;
172  prob_t scenario_prob = 0;
173  prob_t temp_arr[16];
174  uint n = 0;
175  seq_len_t left_pos, right_pos, start_shift;
176 
177 //#ifdef YDEBUG
178 // if (left_end_pos - left_start_pos + 1 != maag.nodeRows(forw_node)) { throw(std::runtime_error("Wrong position boundaries (forward node)!")); }
179 // if (right_end_pos - right_start_pos + 1 != maag.nodeColumns(back_node)) { throw(std::runtime_error("Wrong position boundaries (backward node)!")); }
180 //#endif
181 
182  if (maag.is_vj()) {
183  for (dim_t row_i = 0; row_i < maag.nodeRows(ins_node); ++row_i) {
184  left_pos = left_start_pos + row_i;
185  for (dim_t col_i = 0; col_i < maag.nodeColumns(ins_node); ++col_i) {
186  right_pos = right_start_pos + col_i;
187 
188  if (maag.position(right_pos) - maag.position(left_pos) - 1 > 0 && maag.event_index(ins_node, 0, row_i, col_i)) {
189  fill(temp_arr, temp_arr + 4, 0);
190  n = 0;
191 
192  scenario_prob = (*_forward_acc)(ins_node, 0, row_i, col_i)
193  * (*_backward_acc)(back_node, 0, row_i, col_i);
194 
195  if (_err_mode == NO_ERRORS) {
196  for (seq_len_t pos = maag.position(left_pos) + 1; pos < maag.position(right_pos); ++pos) {
197  temp_arr[nuc_hash(maag.sequence()[pos - 1])] += 1;
198  ++n;
199  }
200  } else {
201  for (seq_len_t pos = maag.position(left_pos) + 1; pos < maag.position(right_pos); ++pos) {
202  temp_arr[0] += _err_prob / 3;
203  temp_arr[1] += _err_prob / 3;
204  temp_arr[2] += _err_prob / 3;
205  temp_arr[3] += _err_prob / 3;
206  temp_arr[nuc_hash(maag.sequence()[pos - 1])] += (1 - _err_prob / 3);
207  ++n;
208  }
209 
210  _err_prob += scenario_prob * (maag.position(right_pos) - maag.position(left_pos) - 1) / maag.n_poses();
211  }
212 
213 // std::cout << "nucs" << std::endl;
214  for (auto i = 0; i < 4; ++i) {
215 // std::cout << nuc_arr[i] << ":" << temp_arr[i] << ":" << scenario_prob << ":" << (int) n << std::endl;
216  nuc_arr[i] += (temp_arr[i] * scenario_prob) / n;
217 // std::cout << nuc_arr[i] << std::endl;
218 // check_and_throw(std::isnan(nuc_arr[i]), "nan in nucs");
219  }
220  }
221  }
222  }
223  } else {
224  for (dim_t row_i = 0; row_i < maag.nodeRows(ins_node); ++row_i) {
225  left_pos = left_start_pos + row_i;
226  for (dim_t col_i = 0; col_i < maag.nodeColumns(ins_node); ++col_i) {
227  right_pos = right_start_pos + col_i;
228 
229  if (maag.position(right_pos) - maag.position(left_pos) - 1 > 0) {
230  fill(temp_arr, temp_arr + 16, 0);
231  n = 0;
232  start_shift = 0;
233 
234  scenario_prob = (*_forward_acc)(ins_node, 0, row_i, col_i)
235  * (*_backward_acc)(back_node, 0, row_i, col_i);
236 
237  if (!reversed) {
238  if (maag.position(left_pos) == 0) {
239  start_shift = 1;
240  temp_arr[4 * nuc_hash('A') + nuc_hash(maag.sequence()[0])] = .25;
241  temp_arr[4 * nuc_hash('C') + nuc_hash(maag.sequence()[0])] = .25;
242  temp_arr[4 * nuc_hash('G') + nuc_hash(maag.sequence()[0])] = .25;
243  temp_arr[4 * nuc_hash('T') + nuc_hash(maag.sequence()[0])] = .25;
244  ++n;
245  }
246 
247  if (_err_mode == NO_ERRORS) {
248  for (seq_len_t pos = maag.position(left_pos) + start_shift + 1; pos < maag.position(right_pos); ++pos) {
249  temp_arr[4 * nuc_hash(maag.sequence()[pos - 2]) + nuc_hash(maag.sequence()[pos - 1])] += 1;
250  ++n;
251  }
252  } else {
253  for (seq_len_t pos = maag.position(left_pos) + start_shift + 1; pos < maag.position(right_pos); ++pos) {
254  for (int i = 0; i < 16; ++i) {
255  temp_arr[0] += _err_prob * _err_prob / 15;
256  }
257  temp_arr[4 * nuc_hash(maag.sequence()[pos - 2]) + nuc_hash(maag.sequence()[pos - 1])] += (1 - _err_prob * _err_prob / 15);
258  ++n;
259  }
260 
261  _err_prob += scenario_prob * (maag.position(right_pos) - maag.position(left_pos) - 1) / maag.n_poses();
262  }
263  } else {
264  if (maag.position(right_pos) == maag.sequence().size() + 1) {
265  start_shift = 1;
266  temp_arr[4 * nuc_hash('A') + nuc_hash(maag.sequence()[maag.sequence().size() - 1])] = .25;
267  temp_arr[4 * nuc_hash('C') + nuc_hash(maag.sequence()[maag.sequence().size() - 1])] = .25;
268  temp_arr[4 * nuc_hash('G') + nuc_hash(maag.sequence()[maag.sequence().size() - 1])] = .25;
269  temp_arr[4 * nuc_hash('T') + nuc_hash(maag.sequence()[maag.sequence().size() - 1])] = .25;
270  ++n;
271  }
272 
273  if (_err_mode == NO_ERRORS) {
274  for (seq_len_t pos = maag.position(right_pos) - start_shift; pos > maag.position(left_pos) + 1; --pos) {
275  temp_arr[4 * nuc_hash(maag.sequence()[pos - 1]) + nuc_hash(maag.sequence()[pos - 2])] += 1;
276  ++n;
277  }
278  } else {
279  for (seq_len_t pos = maag.position(right_pos) - start_shift; pos > maag.position(left_pos) + 1; --pos) {
280  for (int i = 0; i < 16; ++i) {
281  temp_arr[0] += _err_prob * _err_prob / 15;
282  }
283  temp_arr[4 * nuc_hash(maag.sequence()[pos - 1]) + nuc_hash(maag.sequence()[pos - 2])] += (1 - _err_prob * _err_prob / 15);
284  ++n;
285  }
286 
287  _err_prob += scenario_prob * (maag.position(right_pos) - maag.position(left_pos) - 1) / maag.n_poses();
288  }
289  }
290 
291  for (auto i = 0; i < 16; ++i) {
292  nuc_arr[i] += (temp_arr[i] * scenario_prob) / n;
293  }
294  }
295  }
296  }
297  }
298  }
299 
300 
304  void pushEventValue(event_ind_t event_index, prob_t prob_value) {
306  if (prob_value && event_index) {
307 
308  auto elem = _pair_map.find(event_index);
309 
310  if (elem == _pair_map.end()) {
311  _pair_map[event_index] = 0;
312  }
313  _pair_map[event_index] += prob_value;
314  }
315  }
316 
317  void pushEventPair(const MAAG &maag, node_ind_t node_i, matrix_ind_t maag_mat_i, dim_t maag_row_i, dim_t maag_col_i,
318  matrix_ind_t fb_mat_i, dim_t fb_row_i, dim_t fb_col_i) {
319  this->pushEventValue(maag.event_index(node_i, maag_mat_i, maag_row_i, maag_col_i),
320  (*_forward_acc)(node_i, fb_mat_i, fb_row_i, fb_col_i) * (*_backward_acc)(node_i, fb_mat_i, fb_row_i, fb_col_i));
321  }
322 
323  void pushEventPairs(const MAAG &maag, node_ind_t node_i, matrix_ind_t maag_mat_i, matrix_ind_t fb_mat_i) {
324  for (dim_t row_i = 0; row_i < maag.nodeRows(node_i); ++row_i) {
325  for (dim_t col_i = 0; col_i < maag.nodeColumns(node_i); ++col_i) {
326  this->pushEventPair(maag, node_i, maag_mat_i, row_i, col_i, fb_mat_i, row_i, col_i);
327  }
328  }
329  }
330 
331  void pushEventPairsWithErrors(const MAAG &maag, node_ind_t node_i, matrix_ind_t maag_mat_i, matrix_ind_t fb_mat_i, node_ind_t err_node_i) {
332  this->pushEventPairs(maag, node_i, maag_mat_i, fb_mat_i);
333 
334  if (_err_mode == COMPUTE_ERRORS) {
335  for (dim_t row_i = 0; row_i < maag.nodeRows(node_i); ++row_i) {
336  for (dim_t col_i = 0; col_i < maag.nodeColumns(node_i); ++col_i) {
337  if (maag.errors(err_node_i, fb_mat_i, row_i, col_i)) {
338  _err_prob += (*_forward_acc)(node_i, fb_mat_i, row_i, col_i)
339  * (*_backward_acc)(node_i, fb_mat_i, row_i, col_i)
340  * (maag.errors(err_node_i, fb_mat_i, row_i, col_i) / maag.n_poses());
341  }
342  }
343  }
344  }
345  }
347 
348 
352  void vectorise_pair_map(const MAAG &maag) {
353  if (maag.is_vj()) {
354  _nuc_arr1[0] /= _full_prob;
355  _nuc_arr1[1] /= _full_prob;
356  _nuc_arr1[2] /= _full_prob;
357  _nuc_arr1[3] /= _full_prob;
358  } else {
359  for (int i = 0; i < 16; ++i) {
360  _nuc_arr1[i] /= _full_prob;
361  }
362  for (int i = 0; i < 16; ++i) {
363  _nuc_arr2[i] /= _full_prob;
364  }
365  }
366 
367  _pairs.reserve(_pair_map.size() + 40);
368  for (auto it = _pair_map.begin(); it != _pair_map.end(); ++it) {
369  _pairs.push_back(event_pair_t(it->first, it->second / _full_prob));
370  }
371  _pair_map.clear();
372  }
373 
374 
375  //
376  // Forward-backward algorithm for VJ recombination receptors
377  //
378 
379  // make a matrix chain with forward probabilities for VJ receptors
380  void forward_vj(const MAAG &maag, event_ind_t j_ind) {
381  this->fillZero(_forward_acc.get());
382 
383  // VJ probabilities for the fixed J
384  for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_VAR_JOI_GEN_I); ++row_i) {
385  (*_forward_acc)(VJ_VAR_JOI_GEN_I, 0, row_i, 0) = maag(VJ_VAR_JOI_GEN_I, 0, row_i, j_ind);
386  }
387 
388  // V deletions
389  for (event_ind_t v_ind = 0; v_ind < maag.nVar(); ++v_ind) {
390  for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_DEL_I); ++col_i) {
391  (*_forward_acc)(VJ_VAR_DEL_I, v_ind, 0, col_i) =
392  (*_forward_acc)(VJ_VAR_JOI_GEN_I, 0, v_ind, 0) * maag(VJ_VAR_DEL_I, v_ind, 0, col_i);
393  }
394  }
395 
396  // VJ insertions
397  for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_VAR_JOI_INS_I); ++row_i) {
398  prob_t temp_prob = 0;
399  // sum of fi for V del
400  for (dim_t v_i = 0; v_i < maag.nodeSize(VJ_VAR_DEL_I); ++v_i) {
401  temp_prob += (*_forward_acc)(VJ_VAR_DEL_I, v_i, 0, row_i);
402  }
403 
404  for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_JOI_INS_I); ++col_i) {
405  (*_forward_acc)(VJ_VAR_JOI_INS_I, 0, row_i, col_i) =
406  temp_prob * maag(VJ_VAR_JOI_INS_I, 0, row_i, col_i);
407  }
408  }
409 
410  // J deletions
411  for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_JOI_DEL_I); ++row_i) {
412  for (dim_t row_vj_i = 0; row_vj_i < maag.nodeRows(VJ_VAR_JOI_INS_I); ++row_vj_i) {
413  (*_forward_acc)(VJ_JOI_DEL_I, 0, row_i, 0) += (*_forward_acc)(VJ_VAR_JOI_INS_I, 0, row_vj_i, row_i);
414  }
415  (*_forward_acc)(VJ_JOI_DEL_I, 0, row_i, 0) *= maag(VJ_JOI_DEL_I, j_ind, row_i, 0);
416  }
417 
418 
419  // update the full generation probability
420  for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_JOI_DEL_I); ++row_i) {
421  _full_prob += (*_forward_acc)(VJ_JOI_DEL_I, 0, row_i, 0);
422  }
423  }
424 
425 
426  // make a matrix chain with backward probabilities for VJ receptors
427  void backward_vj(const MAAG &maag, event_ind_t j_ind) {
428  this->fillZero(_backward_acc.get());
429 
430  // J deletions
431  for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_JOI_DEL_I); ++row_i) {
432  (*_backward_acc)(VJ_JOI_DEL_I, 0, row_i, 0) = 1;
433  }
434 
435  // VJ insertions
436  for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_JOI_INS_I); ++col_i) {
437  for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_VAR_JOI_INS_I); ++row_i) {
438  (*_backward_acc)(VJ_VAR_JOI_INS_I, 0, row_i, col_i) += maag(VJ_JOI_DEL_I, j_ind, col_i, 0);
439  }
440  }
441 
442  // V deletions
443  for (event_ind_t v_ind = 0; v_ind < maag.nVar(); ++v_ind) {
444  for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_DEL_I); ++col_i) {
445  for (dim_t ins_col_i = 0; ins_col_i < maag.nodeColumns(VJ_VAR_JOI_INS_I); ++ins_col_i) {
446  (*_backward_acc)(VJ_VAR_DEL_I, v_ind, 0, col_i) +=
447  (*_backward_acc)(VJ_VAR_JOI_INS_I, 0, col_i, ins_col_i) * maag(VJ_VAR_JOI_INS_I, 0, col_i, ins_col_i);
448  }
449  }
450  }
451 
452  // V-J genes
453  for (event_ind_t v_ind = 0; v_ind < maag.nVar(); ++v_ind) {
454  for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_DEL_I); ++col_i) {
455  (*_backward_acc)(VJ_VAR_JOI_GEN_I, 0, v_ind, 0) +=
456  (*_backward_acc)(VJ_VAR_DEL_I, v_ind, 0, col_i) * maag(VJ_VAR_DEL_I, v_ind, 0, col_i);
457  }
458  }
459 
460  for (event_ind_t v_ind = 0; v_ind < maag.nVar(); ++v_ind) {
461  _back_full_prob +=
462  (*_backward_acc)(VJ_VAR_JOI_GEN_I, 0, v_ind, 0) * maag(VJ_VAR_JOI_GEN_I, 0, v_ind, j_ind);
463  }
464  }
465 
466 
467  // compute all forward-backward probabilities
468  void forward_backward_vj(const MAAG &maag) {
469  _forward_acc.reset(new ProbMMC());
470  _forward_acc->resize(maag.chainSize());
471  // VJ probabilities (for fixed J in future)
472  _forward_acc->initNode(VJ_VAR_JOI_GEN_I, 1, maag.nodeRows(VJ_VAR_JOI_GEN_I), 1);
473  // V deletions
474  _forward_acc->initNode(VJ_VAR_DEL_I, maag.nodeSize(VJ_VAR_DEL_I), 1, maag.nodeColumns(VJ_VAR_DEL_I));
475  // VJ insertions
476  _forward_acc->initNode(VJ_VAR_JOI_INS_I, 1, maag.nodeRows(VJ_VAR_JOI_INS_I), maag.nodeColumns(VJ_VAR_JOI_INS_I));
477  // J deletions (for fixed J in future)
478  _forward_acc->initNode(VJ_JOI_DEL_I, 1, maag.nodeRows(VJ_JOI_DEL_I), 1);
479 
480  _backward_acc.reset(new ProbMMC());
481  _backward_acc->resize(maag.chainSize());
482  // VJ probabilities (for fixed J in future)
483  _backward_acc->initNode(VJ_VAR_JOI_GEN_I, 1, maag.nodeRows(VJ_VAR_JOI_GEN_I), 1);
484  // V deletions
485  _backward_acc->initNode(VJ_VAR_DEL_I, maag.nodeSize(VJ_VAR_DEL_I), 1, maag.nodeColumns(VJ_VAR_DEL_I));
486  // VJ insertions
487  _backward_acc->initNode(VJ_VAR_JOI_INS_I, 1, maag.nodeRows(VJ_VAR_JOI_INS_I), maag.nodeColumns(VJ_VAR_JOI_INS_I));
488  // J deletions (for fixed J in future)
489  _backward_acc->initNode(VJ_JOI_DEL_I, 1, maag.nodeRows(VJ_JOI_DEL_I), 1);
490 
491  // Compute fi * bi / Pgen for each event.
492  for (event_ind_t j_ind = 0; j_ind < maag.nJoi(); ++j_ind) {
493  // compute forward and backward probabilities for a specific J gene
494  this->forward_vj(maag, j_ind);
495  this->backward_vj(maag, j_ind);
496 
497  // add fi * bi for this J to the accumulator
498  for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_VAR_JOI_GEN_I); ++row_i) {
499  this->pushEventPair(maag, VJ_VAR_JOI_GEN_I, 0, row_i, j_ind, 0, row_i, 0);
500  }
501  for (matrix_ind_t mat_i = 0; mat_i < maag.nVar(); ++mat_i) {
502  this->pushEventPairsWithErrors(maag, VJ_VAR_DEL_I, mat_i, mat_i, 0);
503  }
504  this->pushEventPairs(maag, VJ_VAR_JOI_INS_I, 0, 0);
505  this->pushEventPairsWithErrors(maag, VJ_JOI_DEL_I, j_ind, 0, 1);
506 
507  this->inferInsertionNucleotides(maag, VJ_VAR_JOI_INS_I,
508  0, maag.nodeColumns(VJ_VAR_DEL_I) - 1,
509  maag.nodeColumns(VJ_VAR_DEL_I), maag.nodeColumns(VJ_VAR_DEL_I) + maag.nodeRows(VJ_JOI_DEL_I) - 1,
510  _nuc_arr1);
511  }
512 
513  this->vectorise_pair_map(maag);
514  }
515 
516 
517  //
518  // Forward-backward algorithm for VDJ recombination receptors
519  //
520 
521  // make a matrix chain with forward probabilities for VDJ receptors
522  void forward_vdj(const MAAG &maag, event_ind_t d_ind, event_ind_t j_ind, bool recompute_d_gen_fi) {
523  if (recompute_d_gen_fi) {
524  // forward probabilities for (V prob -> V del -> VD ins) are fixed
525  // for all pairs of J-D.
526  // We have already stored in _forward_acc fi for this D, so we don't
527  // need to recompute entire _forward_acc, we just need to recompute
528  // J deletions and J genes fi.
529  this->fillZero(_forward_acc.get(), VDJ_DIV_DEL_I);
530 
531  // D deletions
532  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_DIV_DEL_I); ++row_i) {
533  for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_DIV_DEL_I); ++col_i) {
534  for (dim_t ins_row_i = 0; ins_row_i < maag.nodeRows(VDJ_VAR_DIV_INS_I); ++ins_row_i) {
535  (*_forward_acc)(VDJ_DIV_DEL_I, 0, row_i, col_i) +=
536  (*_forward_acc)(VDJ_VAR_DIV_INS_I, 0, ins_row_i, row_i) * maag(VDJ_DIV_DEL_I, d_ind, row_i, col_i);
537  }
538  }
539  }
540 
541  // DJ insertions
542  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_DIV_JOI_INS_I); ++row_i) {
543  for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_DIV_JOI_INS_I); ++col_i) {
544  for (dim_t dgen_row_i = 0; dgen_row_i < maag.nodeRows(VDJ_DIV_DEL_I); ++dgen_row_i) {
545  (*_forward_acc)(VDJ_DIV_JOI_INS_I, 0, row_i, col_i) +=
546  (*_forward_acc)(VDJ_DIV_DEL_I, 0, dgen_row_i, row_i) * maag(VDJ_DIV_JOI_INS_I, 0, row_i, col_i);
547  }
548  }
549  }
550 
551  } else {
552  this->fillZero(_forward_acc.get(), VDJ_JOI_DEL_I);
553  }
554 
555  // J deletions
556  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_JOI_DEL_I); ++row_i) {
557  for (dim_t ins_row_i = 0; ins_row_i < maag.nodeRows(VDJ_DIV_JOI_INS_I); ++ins_row_i) {
558  (*_forward_acc)(VDJ_JOI_DEL_I, 0, row_i, 0) +=
559  (*_forward_acc)(VDJ_DIV_JOI_INS_I, 0, ins_row_i, row_i) * maag(VDJ_JOI_DEL_I, j_ind, row_i, 0);
560  }
561  }
562 
563  // J-D genes
564  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_JOI_DEL_I); ++row_i) {
565  (*_forward_acc)(VDJ_JOI_DIV_GEN_I, 0, 0, 0) +=
566  (*_forward_acc)(VDJ_JOI_DEL_I, 0, row_i, 0) * maag(VDJ_JOI_DIV_GEN_I, 0, j_ind, d_ind);
567  }
568 
569  // update the full generation probability
570  _full_prob += (*_forward_acc)(VDJ_JOI_DIV_GEN_I, 0, 0, 0);
571  }
572 
573 
574  // make a matrix chain with backward probabilities for VDJ receptors
575  void backward_vdj(const MAAG &maag, event_ind_t d_ind, event_ind_t j_ind) {
576  this->fillZero(_backward_acc.get());
577 
578  // J-D pairs
579  (*_backward_acc)(VDJ_JOI_DIV_GEN_I, 0, 0, 0) = 1;
580 
581  // J deletions
582  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_JOI_DEL_I); ++row_i) {
583  (*_backward_acc)(VDJ_JOI_DEL_I, 0, row_i, 0) += maag(VDJ_JOI_DIV_GEN_I, 0, j_ind, d_ind);
584  }
585 
586  // DJ insertions
587  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_DIV_JOI_INS_I); ++row_i) {
588  for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_DIV_JOI_INS_I); ++col_i) {
589  (*_backward_acc)(VDJ_DIV_JOI_INS_I, 0, row_i, col_i) +=
590  (*_backward_acc)(VDJ_JOI_DEL_I, 0, col_i, 0) * maag(VDJ_JOI_DEL_I, j_ind, col_i, 0);
591  }
592  }
593 
594  // D5'-3' deletions
595  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_DIV_DEL_I); ++row_i) {
596  for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_DIV_DEL_I); ++col_i) {
597  for (dim_t ins_col_i = 0; ins_col_i < maag.nodeColumns(VDJ_DIV_JOI_INS_I); ++ins_col_i) {
598  (*_backward_acc)(VDJ_DIV_DEL_I, 0, row_i, col_i) +=
599  (*_backward_acc)(VDJ_DIV_JOI_INS_I, 0, col_i, ins_col_i) * maag(VDJ_DIV_JOI_INS_I, 0, col_i, ins_col_i);
600  }
601  }
602  }
603 
604  // VD insertions
605  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_VAR_DIV_INS_I); ++row_i) {
606  for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_VAR_DIV_INS_I); ++col_i) {
607  for (dim_t d_col_i = 0; d_col_i < maag.nodeColumns(VDJ_DIV_DEL_I); ++d_col_i) {
608  (*_backward_acc)(VDJ_VAR_DIV_INS_I, 0, row_i, col_i) +=
609  (*_backward_acc)(VDJ_DIV_DEL_I, 0, col_i, d_col_i) * maag(VDJ_DIV_DEL_I, d_ind, col_i, d_col_i);
610  }
611  }
612  }
613 
614  // V deletions and V genes
615  for (event_ind_t v_ind = 0; v_ind < maag.nVar(); ++v_ind) {
616  for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_VAR_DEL_I); ++col_i) {
617  for (dim_t ins_col_i = 0; ins_col_i < maag.nodeColumns(VDJ_VAR_DIV_INS_I); ++ins_col_i) {
618  (*_backward_acc)(VDJ_VAR_DEL_I, v_ind, 0, col_i) +=
619  (*_backward_acc)(VDJ_VAR_DIV_INS_I, 0, col_i, ins_col_i) * maag(VDJ_VAR_DIV_INS_I, 0, col_i, ins_col_i);
620  }
621  (*_backward_acc)(VDJ_VAR_GEN_I, v_ind, 0, 0) +=
622  (*_backward_acc)(VDJ_VAR_DEL_I, v_ind, 0, col_i) * maag(VDJ_VAR_DEL_I, v_ind, 0, col_i);
623  }
624 
625  // update the full (back) generation probability
626  _back_full_prob += (*_backward_acc)(VDJ_VAR_GEN_I, v_ind, 0, 0) * maag(VDJ_VAR_GEN_I, v_ind, 0, 0);
627  }
628  }
629 
630 
631  void forward_backward_vdj(const MAAG &maag) {
632  _forward_acc.reset(new ProbMMC());
633  _forward_acc->resize(maag.chainSize());
634  // V genes - fi which is constant for all J-D pairs
635  _forward_acc->initNode(VDJ_VAR_GEN_I, maag.nodeSize(VDJ_VAR_GEN_I), 1, 1);
636  // V deletions - fi which is constant for all J-D pairs
637  _forward_acc->initNode(VDJ_VAR_DEL_I, maag.nodeSize(VDJ_VAR_DEL_I), 1, maag.nodeColumns(VDJ_VAR_DEL_I));
638  // VD insertions
639  _forward_acc->initNode(VDJ_VAR_DIV_INS_I, 1, maag.nodeRows(VDJ_VAR_DIV_INS_I), maag.nodeColumns(VDJ_VAR_DIV_INS_I));
640  // D5' - 3' deletions
641  _forward_acc->initNode(VDJ_DIV_DEL_I, 1, maag.nodeRows(VDJ_DIV_DEL_I), maag.nodeColumns(VDJ_DIV_DEL_I));
642  // DJ insertions
643  _forward_acc->initNode(VDJ_DIV_JOI_INS_I, 1, maag.nodeRows(VDJ_DIV_JOI_INS_I), maag.nodeColumns(VDJ_DIV_JOI_INS_I));
644  // J deletions
645  _forward_acc->initNode(VDJ_JOI_DEL_I, 1, maag.nodeRows(VDJ_JOI_DEL_I), maag.nodeColumns(VDJ_JOI_DEL_I));
646  // J-D pairs
647 // _forward_acc->initNode(VDJ_JOI_DIV_GEN_I, 1, maag.nodeRows(VDJ_JOI_DIV_GEN_I), 1);
648  _forward_acc->initNode(VDJ_JOI_DIV_GEN_I, 1, 1, 1);
649 
650  _backward_acc.reset(new ProbMMC());
651  _backward_acc->resize(maag.chainSize());
652  // V genes
653  _backward_acc->initNode(VDJ_VAR_GEN_I, maag.nodeSize(VDJ_VAR_GEN_I), 1, 1);
654  // V deletions
655  _backward_acc->initNode(VDJ_VAR_DEL_I, maag.nodeSize(VDJ_VAR_DEL_I), 1, maag.nodeColumns(VDJ_VAR_DEL_I));
656  // VD insertions
657  _backward_acc->initNode(VDJ_VAR_DIV_INS_I, 1, maag.nodeRows(VDJ_VAR_DIV_INS_I), maag.nodeColumns(VDJ_VAR_DIV_INS_I));
658  // D5' - 3' deletions
659  _backward_acc->initNode(VDJ_DIV_DEL_I, 1, maag.nodeRows(VDJ_DIV_DEL_I), maag.nodeColumns(VDJ_DIV_DEL_I));
660  // DJ insertions
661  _backward_acc->initNode(VDJ_DIV_JOI_INS_I, 1, maag.nodeRows(VDJ_DIV_JOI_INS_I), maag.nodeColumns(VDJ_DIV_JOI_INS_I));
662  // J deletions
663  _backward_acc->initNode(VDJ_JOI_DEL_I, 1, maag.nodeRows(VDJ_JOI_DEL_I), maag.nodeColumns(VDJ_JOI_DEL_I));
664  // J-D pairs
665 // _backward_acc->initNode(VDJ_JOI_DIV_GEN_I, 1, maag.nodeRows(VDJ_JOI_DIV_GEN_I), 1);
666  _backward_acc->initNode(VDJ_JOI_DIV_GEN_I, 1, 1, 1);
667 
668  // Because fi for V genes, V deletions and VD insertions are constant for all
669  // pairs of J-D, we compute them here.
670  this->fillZero(_forward_acc.get());
671  // V genes and deletions
672  for (event_ind_t v_ind = 0; v_ind < maag.nVar(); ++v_ind) {
673  // gene probability
674  (*_forward_acc)(VDJ_VAR_GEN_I, v_ind, 0, 0) = maag(VDJ_VAR_GEN_I, v_ind, 0, 0);
675 
676  // deletions probabilities
677  for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_VAR_DEL_I); ++col_i) {
678  (*_forward_acc)(VDJ_VAR_DEL_I, v_ind, 0, col_i) =
679  (*_forward_acc)(VDJ_VAR_GEN_I, v_ind, 0, 0) * maag(VDJ_VAR_DEL_I, v_ind, 0, col_i);
680  }
681  }
682  // VD insertions
683  for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_VAR_DIV_INS_I); ++row_i) {
684  for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_VAR_DIV_INS_I); ++col_i) {
685  for (event_ind_t v_ind = 0; v_ind < maag.nVar(); ++v_ind) {
686  (*_forward_acc)(VDJ_VAR_DIV_INS_I, 0, row_i, col_i) +=
687  (*_forward_acc)(VDJ_VAR_DEL_I, v_ind, 0, row_i) * maag(VDJ_VAR_DIV_INS_I, 0, row_i, col_i);
688  }
689  }
690  }
691 
692  // Compute fi * bi / Pgen for each event.
693  bool recompute_d_gen_fi = true;
694  for (event_ind_t d_ind = 0; d_ind < maag.nDiv(); ++d_ind) {
695  recompute_d_gen_fi = true;
696  for (event_ind_t j_ind = 0; j_ind < maag.nJoi(); ++j_ind) {
697  // compute forward and backward probabilities for a specific J gene
698  this->forward_vdj(maag, d_ind, j_ind, recompute_d_gen_fi);
699  this->backward_vdj(maag, d_ind, j_ind);
700  recompute_d_gen_fi = false;
701 
702  // add fi * bi to the accumulator
703  for (matrix_ind_t mat_i = 0; mat_i < maag.nVar(); ++mat_i) {
704  this->pushEventPairs(maag, VDJ_VAR_GEN_I, mat_i, mat_i);
705  this->pushEventPairsWithErrors(maag, VDJ_VAR_DEL_I, mat_i, mat_i, 0);
706  }
707  this->pushEventPairs(maag, VDJ_VAR_DIV_INS_I, 0, 0);
708  this->pushEventPairsWithErrors(maag, VDJ_DIV_DEL_I, d_ind, 0, 1);
709  this->pushEventPairs(maag, VDJ_DIV_JOI_INS_I, 0, 0);
710  this->pushEventPairsWithErrors(maag, VDJ_JOI_DEL_I, j_ind, 0, 2);
711  this->pushEventPair(maag, VDJ_JOI_DIV_GEN_I, 0, j_ind, d_ind, 0, 0, 0);
712 
713  seq_len_t v_vertices = maag.nodeColumns(VDJ_VAR_DEL_I),
714  d3_vertices = maag.nodeRows(VDJ_DIV_DEL_I),
715  d5_vertices = maag.nodeColumns(VDJ_DIV_DEL_I),
716  j_vertices = maag.nodeRows(VDJ_JOI_DEL_I);
717 
718  this->inferInsertionNucleotides(maag, VDJ_VAR_DIV_INS_I,
719  0, v_vertices - 1,
720  v_vertices, v_vertices + d3_vertices - 1,
721  _nuc_arr1, false);
722 
723 
724  this->inferInsertionNucleotides(maag, VDJ_DIV_JOI_INS_I,
725  v_vertices + d3_vertices, v_vertices + d3_vertices + d5_vertices - 1,
726  v_vertices + d3_vertices + d5_vertices, v_vertices + d3_vertices + d5_vertices + j_vertices - 1,
727  _nuc_arr2, true);
728  }
729  }
730 
731  this->vectorise_pair_map(maag);
732  }
733 
734  };
735 
736 
738  public:
739 
740 
742 
743 
744  virtual ~ForwardBackwardAlgorithm() { }
745 
746 
747  virtual void infer(const MAAG &maag) = 0;
748 
749 
750  event_pair_t nextEvent() {
751  if (_status) {
752  event_pair_t res = _pairs[_pairs_i];
753 // cout << res.second << " -> ";
754 // res.second = res.second / _full_prob;
755 // cout << res.second << endl;
756  ++_pairs_i;
757  return res;
758  }
759 
760  _status = false;
761  return event_pair_t(0, 0);
762  }
763 
764 
765  bool is_empty() {
766 // return _status && _pairs_i != _pairs.size();
767  if (_status) {
768  if (_pairs_i != _pairs.size()) {
769  return false;
770  }
771  }
772  return true;
773  }
774 
775 
776  bool status() const { return _status; }
777 
778 
779  prob_t fullProbability() const { return _full_prob; }
780 
781 
782  prob_t bfullProbability() const { return _back_full_prob; }
783 
784 
785  const vector<event_pair_t>& event_pairs() const { return _pairs; }
786 
787 
788  prob_t* insertion_probs() const;
789 
790 
791  protected:
792  pProbMMC _forward_acc, _backward_acc;
793  bool _status;
794  vector<event_pair_t> _pairs;
795  size_t _pairs_i;
796  prob_t _full_prob, _back_full_prob;
797 
798  };
799 
800 
802 
803  };
804 
805 
807 
808  };
809 
810 
812 
813  };
814 
815 
816 }
817 
818 #endif //YMIR_MAAGFORWARDBACKWARDALGORITHM_H
prob_t _full_prob
Definition: maagforwardbackwardalgorithm.h:140
Definition: aligner.h:37
Definition: maagforwardbackwardalgorithm.h:806
Definition: maagforwardbackwardalgorithm.h:811
prob_t _back_full_prob
Definition: maagforwardbackwardalgorithm.h:141
Definition: maagforwardbackwardalgorithm.h:801
bool _status
Definition: maagforwardbackwardalgorithm.h:793
Definition: maagforwardbackwardalgorithm.h:20
Multi-Alignment Assembly Graph - basic class for representing all possible generation scenarios of a ...
Definition: maag.h:46
uint8_t node_ind_t
Node index type.
Definition: multimatrixchain.h:91
seq_len_t dim_t
Type of dimensions of matrices (rows and columns).
Definition: multimatrixchain.h:108
Definition: maagforwardbackwardalgorithm.h:737
void pushEventValue(event_ind_t event_index, prob_t prob_value)
Access to a hash map which maps event probabilities to event indices.
Definition: maagforwardbackwardalgorithm.h:305
event_ind_t nVar() const
Get the number of aligned gene segments.
Definition: maag.h:334
Class for storing lists of matrices, where one node in the list (called "chain") could contain more t...
Definition: multimatrixchain.h:39
uint8_t matrix_ind_t
Matrix index type.
Definition: multimatrixchain.h:99
vector< event_pair_t > _pairs
Definition: maagforwardbackwardalgorithm.h:143