5 #ifndef YMIR_MAAGFORWARDBACKWARDALGORITHM_H 6 #define YMIR_MAAGFORWARDBACKWARDALGORITHM_H 8 #include <unordered_map> 14 class MAAGForwardBackwardAlgorithm;
15 class ForwardBackwardAlgorithm;
16 class VJForwardBackward;
17 class VDJForwardBackward;
41 process(maag, error_mode);
53 bool process(
const MAAG &maag, ErrorMode error_mode = NO_ERRORS) {
58 _pairs.reserve(maag._chain.size() + 10);
62 _err_mode = error_mode;
64 if (_err_mode == COMPUTE_ERRORS && !maag.has_errors()) {
65 cerr <<
"MAAG forward-backward algorithm error: no error matrix has been found in the input MAAG." << endl;
69 if (maag.has_events()) {
71 if (maag.recombination() == VJ_RECOMB) {
72 fill(_nuc_arr1, _nuc_arr1 + 4, 0);
73 this->forward_backward_vj(maag);
74 }
else if (maag.recombination() == VDJ_RECOMB) {
75 fill(_nuc_arr1, _nuc_arr1 + 16, 0);
76 fill(_nuc_arr2, _nuc_arr2 + 16, 0);
77 this->forward_backward_vdj(maag);
79 cerr <<
"MAAG forward-backward algorithm error: unknown recombination type." << endl;
83 cerr <<
"MAAG forward-backward algorithm error: no event matrix has been found in the input MAAG." << endl;
92 event_pair_t nextEvent() {
94 event_pair_t res =
_pairs[_pairs_i];
103 return event_pair_t(0, 0);
109 if (_pairs_i !=
_pairs.size()) {
117 bool status()
const {
return _status; }
120 prob_t fullProbability()
const {
return _full_prob; }
126 const vector<event_pair_t>& event_pairs()
const {
return _pairs; }
129 prob_t* VJ_nuc_probs() {
return _nuc_arr1; }
130 prob_t* VD_nuc_probs() {
return _nuc_arr1; }
131 prob_t* DJ_nuc_probs() {
return _nuc_arr2; }
133 prob_t err_prob() {
return _err_prob; }
139 pProbMMC _forward_acc, _backward_acc;
145 unordered_map<event_ind_t, prob_t> _pair_map;
146 prob_t _nuc_arr1[16], _nuc_arr2[16];
154 void fillZero(
ProbMMC *mmc, uint start_node = 0) {
155 for (uint i = start_node; i < mmc->chainSize(); ++i) {
156 for (uint j = 0; j < mmc->nodeSize(i); ++j) {
166 void inferInsertionNucleotides(
const MAAG &maag, node_ind_t ins_node,
167 seq_len_t left_start_pos, seq_len_t left_end_pos,
168 seq_len_t right_start_pos, seq_len_t right_end_pos,
169 prob_t *nuc_arr,
bool reversed =
false)
171 node_ind_t forw_node = ins_node - 1, back_node = ins_node;
172 prob_t scenario_prob = 0;
175 seq_len_t left_pos, right_pos, start_shift;
183 for (dim_t row_i = 0; row_i < maag.nodeRows(ins_node); ++row_i) {
184 left_pos = left_start_pos + row_i;
185 for (dim_t col_i = 0; col_i < maag.nodeColumns(ins_node); ++col_i) {
186 right_pos = right_start_pos + col_i;
188 if (maag.position(right_pos) - maag.position(left_pos) - 1 > 0 && maag.event_index(ins_node, 0, row_i, col_i)) {
189 fill(temp_arr, temp_arr + 4, 0);
192 scenario_prob = (*_forward_acc)(ins_node, 0, row_i, col_i)
193 * (*_backward_acc)(back_node, 0, row_i, col_i);
195 if (_err_mode == NO_ERRORS) {
196 for (seq_len_t pos = maag.position(left_pos) + 1; pos < maag.position(right_pos); ++pos) {
197 temp_arr[nuc_hash(maag.sequence()[pos - 1])] += 1;
201 for (seq_len_t pos = maag.position(left_pos) + 1; pos < maag.position(right_pos); ++pos) {
202 temp_arr[0] += _err_prob / 3;
203 temp_arr[1] += _err_prob / 3;
204 temp_arr[2] += _err_prob / 3;
205 temp_arr[3] += _err_prob / 3;
206 temp_arr[nuc_hash(maag.sequence()[pos - 1])] += (1 - _err_prob / 3);
210 _err_prob += scenario_prob * (maag.position(right_pos) - maag.position(left_pos) - 1) / maag.n_poses();
214 for (
auto i = 0; i < 4; ++i) {
216 nuc_arr[i] += (temp_arr[i] * scenario_prob) / n;
224 for (dim_t row_i = 0; row_i < maag.nodeRows(ins_node); ++row_i) {
225 left_pos = left_start_pos + row_i;
226 for (dim_t col_i = 0; col_i < maag.nodeColumns(ins_node); ++col_i) {
227 right_pos = right_start_pos + col_i;
229 if (maag.position(right_pos) - maag.position(left_pos) - 1 > 0) {
230 fill(temp_arr, temp_arr + 16, 0);
234 scenario_prob = (*_forward_acc)(ins_node, 0, row_i, col_i)
235 * (*_backward_acc)(back_node, 0, row_i, col_i);
238 if (maag.position(left_pos) == 0) {
240 temp_arr[4 * nuc_hash(
'A') + nuc_hash(maag.sequence()[0])] = .25;
241 temp_arr[4 * nuc_hash(
'C') + nuc_hash(maag.sequence()[0])] = .25;
242 temp_arr[4 * nuc_hash(
'G') + nuc_hash(maag.sequence()[0])] = .25;
243 temp_arr[4 * nuc_hash(
'T') + nuc_hash(maag.sequence()[0])] = .25;
247 if (_err_mode == NO_ERRORS) {
248 for (seq_len_t pos = maag.position(left_pos) + start_shift + 1; pos < maag.position(right_pos); ++pos) {
249 temp_arr[4 * nuc_hash(maag.sequence()[pos - 2]) + nuc_hash(maag.sequence()[pos - 1])] += 1;
253 for (seq_len_t pos = maag.position(left_pos) + start_shift + 1; pos < maag.position(right_pos); ++pos) {
254 for (
int i = 0; i < 16; ++i) {
255 temp_arr[0] += _err_prob * _err_prob / 15;
257 temp_arr[4 * nuc_hash(maag.sequence()[pos - 2]) + nuc_hash(maag.sequence()[pos - 1])] += (1 - _err_prob * _err_prob / 15);
261 _err_prob += scenario_prob * (maag.position(right_pos) - maag.position(left_pos) - 1) / maag.n_poses();
264 if (maag.position(right_pos) == maag.sequence().size() + 1) {
266 temp_arr[4 * nuc_hash(
'A') + nuc_hash(maag.sequence()[maag.sequence().size() - 1])] = .25;
267 temp_arr[4 * nuc_hash(
'C') + nuc_hash(maag.sequence()[maag.sequence().size() - 1])] = .25;
268 temp_arr[4 * nuc_hash(
'G') + nuc_hash(maag.sequence()[maag.sequence().size() - 1])] = .25;
269 temp_arr[4 * nuc_hash(
'T') + nuc_hash(maag.sequence()[maag.sequence().size() - 1])] = .25;
273 if (_err_mode == NO_ERRORS) {
274 for (seq_len_t pos = maag.position(right_pos) - start_shift; pos > maag.position(left_pos) + 1; --pos) {
275 temp_arr[4 * nuc_hash(maag.sequence()[pos - 1]) + nuc_hash(maag.sequence()[pos - 2])] += 1;
279 for (seq_len_t pos = maag.position(right_pos) - start_shift; pos > maag.position(left_pos) + 1; --pos) {
280 for (
int i = 0; i < 16; ++i) {
281 temp_arr[0] += _err_prob * _err_prob / 15;
283 temp_arr[4 * nuc_hash(maag.sequence()[pos - 1]) + nuc_hash(maag.sequence()[pos - 2])] += (1 - _err_prob * _err_prob / 15);
287 _err_prob += scenario_prob * (maag.position(right_pos) - maag.position(left_pos) - 1) / maag.n_poses();
291 for (
auto i = 0; i < 16; ++i) {
292 nuc_arr[i] += (temp_arr[i] * scenario_prob) / n;
306 if (prob_value && event_index) {
308 auto elem = _pair_map.find(event_index);
310 if (elem == _pair_map.end()) {
311 _pair_map[event_index] = 0;
313 _pair_map[event_index] += prob_value;
317 void pushEventPair(
const MAAG &maag, node_ind_t node_i, matrix_ind_t maag_mat_i, dim_t maag_row_i, dim_t maag_col_i,
318 matrix_ind_t fb_mat_i, dim_t fb_row_i, dim_t fb_col_i) {
319 this->
pushEventValue(maag.event_index(node_i, maag_mat_i, maag_row_i, maag_col_i),
320 (*_forward_acc)(node_i, fb_mat_i, fb_row_i, fb_col_i) * (*_backward_acc)(node_i, fb_mat_i, fb_row_i, fb_col_i));
323 void pushEventPairs(
const MAAG &maag, node_ind_t node_i, matrix_ind_t maag_mat_i, matrix_ind_t fb_mat_i) {
324 for (dim_t row_i = 0; row_i < maag.nodeRows(node_i); ++row_i) {
325 for (dim_t col_i = 0; col_i < maag.nodeColumns(node_i); ++col_i) {
326 this->pushEventPair(maag, node_i, maag_mat_i, row_i, col_i, fb_mat_i, row_i, col_i);
331 void pushEventPairsWithErrors(
const MAAG &maag, node_ind_t node_i, matrix_ind_t maag_mat_i, matrix_ind_t fb_mat_i, node_ind_t err_node_i) {
332 this->pushEventPairs(maag, node_i, maag_mat_i, fb_mat_i);
334 if (_err_mode == COMPUTE_ERRORS) {
335 for (dim_t row_i = 0; row_i < maag.nodeRows(node_i); ++row_i) {
336 for (dim_t col_i = 0; col_i < maag.nodeColumns(node_i); ++col_i) {
337 if (maag.errors(err_node_i, fb_mat_i, row_i, col_i)) {
338 _err_prob += (*_forward_acc)(node_i, fb_mat_i, row_i, col_i)
339 * (*_backward_acc)(node_i, fb_mat_i, row_i, col_i)
340 * (maag.errors(err_node_i, fb_mat_i, row_i, col_i) / maag.n_poses());
352 void vectorise_pair_map(
const MAAG &maag) {
359 for (
int i = 0; i < 16; ++i) {
362 for (
int i = 0; i < 16; ++i) {
367 _pairs.reserve(_pair_map.size() + 40);
368 for (
auto it = _pair_map.begin(); it != _pair_map.end(); ++it) {
369 _pairs.push_back(event_pair_t(it->first, it->second / _full_prob));
380 void forward_vj(
const MAAG &maag, event_ind_t j_ind) {
381 this->fillZero(_forward_acc.get());
384 for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_VAR_JOI_GEN_I); ++row_i) {
385 (*_forward_acc)(VJ_VAR_JOI_GEN_I, 0, row_i, 0) = maag(VJ_VAR_JOI_GEN_I, 0, row_i, j_ind);
389 for (event_ind_t v_ind = 0; v_ind < maag.
nVar(); ++v_ind) {
390 for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_DEL_I); ++col_i) {
391 (*_forward_acc)(VJ_VAR_DEL_I, v_ind, 0, col_i) =
392 (*_forward_acc)(VJ_VAR_JOI_GEN_I, 0, v_ind, 0) * maag(VJ_VAR_DEL_I, v_ind, 0, col_i);
397 for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_VAR_JOI_INS_I); ++row_i) {
398 prob_t temp_prob = 0;
400 for (dim_t v_i = 0; v_i < maag.nodeSize(VJ_VAR_DEL_I); ++v_i) {
401 temp_prob += (*_forward_acc)(VJ_VAR_DEL_I, v_i, 0, row_i);
404 for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_JOI_INS_I); ++col_i) {
405 (*_forward_acc)(VJ_VAR_JOI_INS_I, 0, row_i, col_i) =
406 temp_prob * maag(VJ_VAR_JOI_INS_I, 0, row_i, col_i);
411 for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_JOI_DEL_I); ++row_i) {
412 for (dim_t row_vj_i = 0; row_vj_i < maag.nodeRows(VJ_VAR_JOI_INS_I); ++row_vj_i) {
413 (*_forward_acc)(VJ_JOI_DEL_I, 0, row_i, 0) += (*_forward_acc)(VJ_VAR_JOI_INS_I, 0, row_vj_i, row_i);
415 (*_forward_acc)(VJ_JOI_DEL_I, 0, row_i, 0) *= maag(VJ_JOI_DEL_I, j_ind, row_i, 0);
420 for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_JOI_DEL_I); ++row_i) {
421 _full_prob += (*_forward_acc)(VJ_JOI_DEL_I, 0, row_i, 0);
427 void backward_vj(
const MAAG &maag, event_ind_t j_ind) {
428 this->fillZero(_backward_acc.get());
431 for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_JOI_DEL_I); ++row_i) {
432 (*_backward_acc)(VJ_JOI_DEL_I, 0, row_i, 0) = 1;
436 for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_JOI_INS_I); ++col_i) {
437 for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_VAR_JOI_INS_I); ++row_i) {
438 (*_backward_acc)(VJ_VAR_JOI_INS_I, 0, row_i, col_i) += maag(VJ_JOI_DEL_I, j_ind, col_i, 0);
443 for (event_ind_t v_ind = 0; v_ind < maag.
nVar(); ++v_ind) {
444 for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_DEL_I); ++col_i) {
445 for (dim_t ins_col_i = 0; ins_col_i < maag.nodeColumns(VJ_VAR_JOI_INS_I); ++ins_col_i) {
446 (*_backward_acc)(VJ_VAR_DEL_I, v_ind, 0, col_i) +=
447 (*_backward_acc)(VJ_VAR_JOI_INS_I, 0, col_i, ins_col_i) * maag(VJ_VAR_JOI_INS_I, 0, col_i, ins_col_i);
453 for (event_ind_t v_ind = 0; v_ind < maag.
nVar(); ++v_ind) {
454 for (dim_t col_i = 0; col_i < maag.nodeColumns(VJ_VAR_DEL_I); ++col_i) {
455 (*_backward_acc)(VJ_VAR_JOI_GEN_I, 0, v_ind, 0) +=
456 (*_backward_acc)(VJ_VAR_DEL_I, v_ind, 0, col_i) * maag(VJ_VAR_DEL_I, v_ind, 0, col_i);
460 for (event_ind_t v_ind = 0; v_ind < maag.
nVar(); ++v_ind) {
462 (*_backward_acc)(VJ_VAR_JOI_GEN_I, 0, v_ind, 0) * maag(VJ_VAR_JOI_GEN_I, 0, v_ind, j_ind);
468 void forward_backward_vj(
const MAAG &maag) {
469 _forward_acc.reset(
new ProbMMC());
470 _forward_acc->resize(maag.chainSize());
472 _forward_acc->initNode(VJ_VAR_JOI_GEN_I, 1, maag.nodeRows(VJ_VAR_JOI_GEN_I), 1);
474 _forward_acc->initNode(VJ_VAR_DEL_I, maag.nodeSize(VJ_VAR_DEL_I), 1, maag.nodeColumns(VJ_VAR_DEL_I));
476 _forward_acc->initNode(VJ_VAR_JOI_INS_I, 1, maag.nodeRows(VJ_VAR_JOI_INS_I), maag.nodeColumns(VJ_VAR_JOI_INS_I));
478 _forward_acc->initNode(VJ_JOI_DEL_I, 1, maag.nodeRows(VJ_JOI_DEL_I), 1);
480 _backward_acc.reset(
new ProbMMC());
481 _backward_acc->resize(maag.chainSize());
483 _backward_acc->initNode(VJ_VAR_JOI_GEN_I, 1, maag.nodeRows(VJ_VAR_JOI_GEN_I), 1);
485 _backward_acc->initNode(VJ_VAR_DEL_I, maag.nodeSize(VJ_VAR_DEL_I), 1, maag.nodeColumns(VJ_VAR_DEL_I));
487 _backward_acc->initNode(VJ_VAR_JOI_INS_I, 1, maag.nodeRows(VJ_VAR_JOI_INS_I), maag.nodeColumns(VJ_VAR_JOI_INS_I));
489 _backward_acc->initNode(VJ_JOI_DEL_I, 1, maag.nodeRows(VJ_JOI_DEL_I), 1);
492 for (event_ind_t j_ind = 0; j_ind < maag.nJoi(); ++j_ind) {
494 this->forward_vj(maag, j_ind);
495 this->backward_vj(maag, j_ind);
498 for (dim_t row_i = 0; row_i < maag.nodeRows(VJ_VAR_JOI_GEN_I); ++row_i) {
499 this->pushEventPair(maag, VJ_VAR_JOI_GEN_I, 0, row_i, j_ind, 0, row_i, 0);
501 for (matrix_ind_t mat_i = 0; mat_i < maag.
nVar(); ++mat_i) {
502 this->pushEventPairsWithErrors(maag, VJ_VAR_DEL_I, mat_i, mat_i, 0);
504 this->pushEventPairs(maag, VJ_VAR_JOI_INS_I, 0, 0);
505 this->pushEventPairsWithErrors(maag, VJ_JOI_DEL_I, j_ind, 0, 1);
507 this->inferInsertionNucleotides(maag, VJ_VAR_JOI_INS_I,
508 0, maag.nodeColumns(VJ_VAR_DEL_I) - 1,
509 maag.nodeColumns(VJ_VAR_DEL_I), maag.nodeColumns(VJ_VAR_DEL_I) + maag.nodeRows(VJ_JOI_DEL_I) - 1,
513 this->vectorise_pair_map(maag);
522 void forward_vdj(
const MAAG &maag, event_ind_t d_ind, event_ind_t j_ind,
bool recompute_d_gen_fi) {
523 if (recompute_d_gen_fi) {
529 this->fillZero(_forward_acc.get(), VDJ_DIV_DEL_I);
532 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_DIV_DEL_I); ++row_i) {
533 for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_DIV_DEL_I); ++col_i) {
534 for (dim_t ins_row_i = 0; ins_row_i < maag.nodeRows(VDJ_VAR_DIV_INS_I); ++ins_row_i) {
535 (*_forward_acc)(VDJ_DIV_DEL_I, 0, row_i, col_i) +=
536 (*_forward_acc)(VDJ_VAR_DIV_INS_I, 0, ins_row_i, row_i) * maag(VDJ_DIV_DEL_I, d_ind, row_i, col_i);
542 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_DIV_JOI_INS_I); ++row_i) {
543 for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_DIV_JOI_INS_I); ++col_i) {
544 for (dim_t dgen_row_i = 0; dgen_row_i < maag.nodeRows(VDJ_DIV_DEL_I); ++dgen_row_i) {
545 (*_forward_acc)(VDJ_DIV_JOI_INS_I, 0, row_i, col_i) +=
546 (*_forward_acc)(VDJ_DIV_DEL_I, 0, dgen_row_i, row_i) * maag(VDJ_DIV_JOI_INS_I, 0, row_i, col_i);
552 this->fillZero(_forward_acc.get(), VDJ_JOI_DEL_I);
556 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_JOI_DEL_I); ++row_i) {
557 for (dim_t ins_row_i = 0; ins_row_i < maag.nodeRows(VDJ_DIV_JOI_INS_I); ++ins_row_i) {
558 (*_forward_acc)(VDJ_JOI_DEL_I, 0, row_i, 0) +=
559 (*_forward_acc)(VDJ_DIV_JOI_INS_I, 0, ins_row_i, row_i) * maag(VDJ_JOI_DEL_I, j_ind, row_i, 0);
564 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_JOI_DEL_I); ++row_i) {
565 (*_forward_acc)(VDJ_JOI_DIV_GEN_I, 0, 0, 0) +=
566 (*_forward_acc)(VDJ_JOI_DEL_I, 0, row_i, 0) * maag(VDJ_JOI_DIV_GEN_I, 0, j_ind, d_ind);
570 _full_prob += (*_forward_acc)(VDJ_JOI_DIV_GEN_I, 0, 0, 0);
575 void backward_vdj(
const MAAG &maag, event_ind_t d_ind, event_ind_t j_ind) {
576 this->fillZero(_backward_acc.get());
579 (*_backward_acc)(VDJ_JOI_DIV_GEN_I, 0, 0, 0) = 1;
582 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_JOI_DEL_I); ++row_i) {
583 (*_backward_acc)(VDJ_JOI_DEL_I, 0, row_i, 0) += maag(VDJ_JOI_DIV_GEN_I, 0, j_ind, d_ind);
587 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_DIV_JOI_INS_I); ++row_i) {
588 for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_DIV_JOI_INS_I); ++col_i) {
589 (*_backward_acc)(VDJ_DIV_JOI_INS_I, 0, row_i, col_i) +=
590 (*_backward_acc)(VDJ_JOI_DEL_I, 0, col_i, 0) * maag(VDJ_JOI_DEL_I, j_ind, col_i, 0);
595 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_DIV_DEL_I); ++row_i) {
596 for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_DIV_DEL_I); ++col_i) {
597 for (dim_t ins_col_i = 0; ins_col_i < maag.nodeColumns(VDJ_DIV_JOI_INS_I); ++ins_col_i) {
598 (*_backward_acc)(VDJ_DIV_DEL_I, 0, row_i, col_i) +=
599 (*_backward_acc)(VDJ_DIV_JOI_INS_I, 0, col_i, ins_col_i) * maag(VDJ_DIV_JOI_INS_I, 0, col_i, ins_col_i);
605 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_VAR_DIV_INS_I); ++row_i) {
606 for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_VAR_DIV_INS_I); ++col_i) {
607 for (dim_t d_col_i = 0; d_col_i < maag.nodeColumns(VDJ_DIV_DEL_I); ++d_col_i) {
608 (*_backward_acc)(VDJ_VAR_DIV_INS_I, 0, row_i, col_i) +=
609 (*_backward_acc)(VDJ_DIV_DEL_I, 0, col_i, d_col_i) * maag(VDJ_DIV_DEL_I, d_ind, col_i, d_col_i);
615 for (event_ind_t v_ind = 0; v_ind < maag.
nVar(); ++v_ind) {
616 for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_VAR_DEL_I); ++col_i) {
617 for (dim_t ins_col_i = 0; ins_col_i < maag.nodeColumns(VDJ_VAR_DIV_INS_I); ++ins_col_i) {
618 (*_backward_acc)(VDJ_VAR_DEL_I, v_ind, 0, col_i) +=
619 (*_backward_acc)(VDJ_VAR_DIV_INS_I, 0, col_i, ins_col_i) * maag(VDJ_VAR_DIV_INS_I, 0, col_i, ins_col_i);
621 (*_backward_acc)(VDJ_VAR_GEN_I, v_ind, 0, 0) +=
622 (*_backward_acc)(VDJ_VAR_DEL_I, v_ind, 0, col_i) * maag(VDJ_VAR_DEL_I, v_ind, 0, col_i);
626 _back_full_prob += (*_backward_acc)(VDJ_VAR_GEN_I, v_ind, 0, 0) * maag(VDJ_VAR_GEN_I, v_ind, 0, 0);
631 void forward_backward_vdj(
const MAAG &maag) {
632 _forward_acc.reset(
new ProbMMC());
633 _forward_acc->resize(maag.chainSize());
635 _forward_acc->initNode(VDJ_VAR_GEN_I, maag.nodeSize(VDJ_VAR_GEN_I), 1, 1);
637 _forward_acc->initNode(VDJ_VAR_DEL_I, maag.nodeSize(VDJ_VAR_DEL_I), 1, maag.nodeColumns(VDJ_VAR_DEL_I));
639 _forward_acc->initNode(VDJ_VAR_DIV_INS_I, 1, maag.nodeRows(VDJ_VAR_DIV_INS_I), maag.nodeColumns(VDJ_VAR_DIV_INS_I));
641 _forward_acc->initNode(VDJ_DIV_DEL_I, 1, maag.nodeRows(VDJ_DIV_DEL_I), maag.nodeColumns(VDJ_DIV_DEL_I));
643 _forward_acc->initNode(VDJ_DIV_JOI_INS_I, 1, maag.nodeRows(VDJ_DIV_JOI_INS_I), maag.nodeColumns(VDJ_DIV_JOI_INS_I));
645 _forward_acc->initNode(VDJ_JOI_DEL_I, 1, maag.nodeRows(VDJ_JOI_DEL_I), maag.nodeColumns(VDJ_JOI_DEL_I));
648 _forward_acc->initNode(VDJ_JOI_DIV_GEN_I, 1, 1, 1);
650 _backward_acc.reset(
new ProbMMC());
651 _backward_acc->resize(maag.chainSize());
653 _backward_acc->initNode(VDJ_VAR_GEN_I, maag.nodeSize(VDJ_VAR_GEN_I), 1, 1);
655 _backward_acc->initNode(VDJ_VAR_DEL_I, maag.nodeSize(VDJ_VAR_DEL_I), 1, maag.nodeColumns(VDJ_VAR_DEL_I));
657 _backward_acc->initNode(VDJ_VAR_DIV_INS_I, 1, maag.nodeRows(VDJ_VAR_DIV_INS_I), maag.nodeColumns(VDJ_VAR_DIV_INS_I));
659 _backward_acc->initNode(VDJ_DIV_DEL_I, 1, maag.nodeRows(VDJ_DIV_DEL_I), maag.nodeColumns(VDJ_DIV_DEL_I));
661 _backward_acc->initNode(VDJ_DIV_JOI_INS_I, 1, maag.nodeRows(VDJ_DIV_JOI_INS_I), maag.nodeColumns(VDJ_DIV_JOI_INS_I));
663 _backward_acc->initNode(VDJ_JOI_DEL_I, 1, maag.nodeRows(VDJ_JOI_DEL_I), maag.nodeColumns(VDJ_JOI_DEL_I));
666 _backward_acc->initNode(VDJ_JOI_DIV_GEN_I, 1, 1, 1);
670 this->fillZero(_forward_acc.get());
672 for (event_ind_t v_ind = 0; v_ind < maag.
nVar(); ++v_ind) {
674 (*_forward_acc)(VDJ_VAR_GEN_I, v_ind, 0, 0) = maag(VDJ_VAR_GEN_I, v_ind, 0, 0);
677 for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_VAR_DEL_I); ++col_i) {
678 (*_forward_acc)(VDJ_VAR_DEL_I, v_ind, 0, col_i) =
679 (*_forward_acc)(VDJ_VAR_GEN_I, v_ind, 0, 0) * maag(VDJ_VAR_DEL_I, v_ind, 0, col_i);
683 for (dim_t row_i = 0; row_i < maag.nodeRows(VDJ_VAR_DIV_INS_I); ++row_i) {
684 for (dim_t col_i = 0; col_i < maag.nodeColumns(VDJ_VAR_DIV_INS_I); ++col_i) {
685 for (event_ind_t v_ind = 0; v_ind < maag.
nVar(); ++v_ind) {
686 (*_forward_acc)(VDJ_VAR_DIV_INS_I, 0, row_i, col_i) +=
687 (*_forward_acc)(VDJ_VAR_DEL_I, v_ind, 0, row_i) * maag(VDJ_VAR_DIV_INS_I, 0, row_i, col_i);
693 bool recompute_d_gen_fi =
true;
694 for (event_ind_t d_ind = 0; d_ind < maag.nDiv(); ++d_ind) {
695 recompute_d_gen_fi =
true;
696 for (event_ind_t j_ind = 0; j_ind < maag.nJoi(); ++j_ind) {
698 this->forward_vdj(maag, d_ind, j_ind, recompute_d_gen_fi);
699 this->backward_vdj(maag, d_ind, j_ind);
700 recompute_d_gen_fi =
false;
703 for (matrix_ind_t mat_i = 0; mat_i < maag.
nVar(); ++mat_i) {
704 this->pushEventPairs(maag, VDJ_VAR_GEN_I, mat_i, mat_i);
705 this->pushEventPairsWithErrors(maag, VDJ_VAR_DEL_I, mat_i, mat_i, 0);
707 this->pushEventPairs(maag, VDJ_VAR_DIV_INS_I, 0, 0);
708 this->pushEventPairsWithErrors(maag, VDJ_DIV_DEL_I, d_ind, 0, 1);
709 this->pushEventPairs(maag, VDJ_DIV_JOI_INS_I, 0, 0);
710 this->pushEventPairsWithErrors(maag, VDJ_JOI_DEL_I, j_ind, 0, 2);
711 this->pushEventPair(maag, VDJ_JOI_DIV_GEN_I, 0, j_ind, d_ind, 0, 0, 0);
713 seq_len_t v_vertices = maag.nodeColumns(VDJ_VAR_DEL_I),
714 d3_vertices = maag.nodeRows(VDJ_DIV_DEL_I),
715 d5_vertices = maag.nodeColumns(VDJ_DIV_DEL_I),
716 j_vertices = maag.nodeRows(VDJ_JOI_DEL_I);
718 this->inferInsertionNucleotides(maag, VDJ_VAR_DIV_INS_I,
720 v_vertices, v_vertices + d3_vertices - 1,
724 this->inferInsertionNucleotides(maag, VDJ_DIV_JOI_INS_I,
725 v_vertices + d3_vertices, v_vertices + d3_vertices + d5_vertices - 1,
726 v_vertices + d3_vertices + d5_vertices, v_vertices + d3_vertices + d5_vertices + j_vertices - 1,
731 this->vectorise_pair_map(maag);
747 virtual void infer(
const MAAG &maag) = 0;
750 event_pair_t nextEvent() {
752 event_pair_t res =
_pairs[_pairs_i];
761 return event_pair_t(0, 0);
768 if (_pairs_i !=
_pairs.size()) {
776 bool status()
const {
return _status; }
779 prob_t fullProbability()
const {
return _full_prob; }
785 const vector<event_pair_t>& event_pairs()
const {
return _pairs; }
788 prob_t* insertion_probs()
const;
792 pProbMMC _forward_acc, _backward_acc;
794 vector<event_pair_t>
_pairs;
818 #endif //YMIR_MAAGFORWARDBACKWARDALGORITHM_H prob_t _full_prob
Definition: maagforwardbackwardalgorithm.h:140
Definition: maagforwardbackwardalgorithm.h:806
Definition: maagforwardbackwardalgorithm.h:811
prob_t _back_full_prob
Definition: maagforwardbackwardalgorithm.h:141
Definition: maagforwardbackwardalgorithm.h:801
bool _status
Definition: maagforwardbackwardalgorithm.h:793
Definition: maagforwardbackwardalgorithm.h:20
Multi-Alignment Assembly Graph - basic class for representing all possible generation scenarios of a ...
Definition: maag.h:46
uint8_t node_ind_t
Node index type.
Definition: multimatrixchain.h:91
seq_len_t dim_t
Type of dimensions of matrices (rows and columns).
Definition: multimatrixchain.h:108
Definition: maagforwardbackwardalgorithm.h:737
void pushEventValue(event_ind_t event_index, prob_t prob_value)
Access to a hash map which maps event probabilities to event indices.
Definition: maagforwardbackwardalgorithm.h:305
event_ind_t nVar() const
Get the number of aligned gene segments.
Definition: maag.h:334
Class for storing lists of matrices, where one node in the list (called "chain") could contain more t...
Definition: multimatrixchain.h:39
uint8_t matrix_ind_t
Matrix index type.
Definition: multimatrixchain.h:99
vector< event_pair_t > _pairs
Definition: maagforwardbackwardalgorithm.h:143