From a20d7ba384916c3b6df66bda4095b992d34c5a8b Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Fri, 11 Sep 2020 06:12:14 +0300 Subject: [PATCH] Added callback to disable PriorBox to PriorBoxIE transformation (#2159) * move priorbox to ie transformations to Opset1ToLegacyOpset pipeline * fix typo * Revert "fix typo" This reverts commit 4077a78cbdcc1340b0999a3756ce348dd567e45f. * Revert "move priorbox to ie transformations to Opset1ToLegacyOpset pipeline" This reverts commit 910e41ff2099b0516fc9bff2baaa2ec678c8a7bf. * add functionality to disable prior box to ie transformation * fix callback --- .../convert_prior_to_ie_prior.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.cpp index b881df50b36..659c85de618 100644 --- a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.cpp @@ -33,14 +33,14 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box() { auto prior_box = std::make_shared(data, image, attr); auto unsqueeze = std::make_shared (prior_box, axes); - ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) { + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { auto unsqueeze = std::dynamic_pointer_cast (m.get_match_root()); if (!unsqueeze) { return false; } auto prior_box_node = std::dynamic_pointer_cast (unsqueeze->input_value(0).get_node_shared_ptr()); - if (!prior_box_node) { + if (!prior_box_node || m_transformation_callback(prior_box_node)) { return false; } @@ -172,14 +172,14 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box_clustered() { auto prior_box = std::make_shared(data, image, attr); auto unsqueeze = std::make_shared (prior_box, axes); - ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) { + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { auto unsqueeze = std::dynamic_pointer_cast (m.get_match_root()); if (!unsqueeze) { return false; } auto prior_box_node = std::dynamic_pointer_cast(unsqueeze->input_value(0).get_node_shared_ptr()); - if (!prior_box_node) { + if (!prior_box_node || m_transformation_callback(prior_box_node)) { return false; }