diff --git a/inference-engine/src/low_precision_transformations/src/reshape.cpp b/inference-engine/src/low_precision_transformations/src/reshape.cpp index aeba26b7e49..11d098aebab 100644 --- a/inference-engine/src/low_precision_transformations/src/reshape.cpp +++ b/inference-engine/src/low_precision_transformations/src/reshape.cpp @@ -12,6 +12,7 @@ #include #include +#include #include "low_precision/common/ie_lpt_exception.hpp" #include "low_precision/network_helper.hpp" @@ -23,13 +24,29 @@ namespace low_precision { NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::ReshapeTransformation, "ReshapeTransformation", 0); ReshapeTransformation::ReshapeTransformation(const Params& params) : LayerTransformation(params) { - auto matcher = pattern::wrap_type({ pattern::wrap_type(), pattern::wrap_type() }); + auto input = pattern::any_input(); + auto mul_const_m = pattern::wrap_type(); + auto mul_m = pattern::wrap_type({ input, mul_const_m }); + auto reshape_pattern_const = pattern::wrap_type(); + auto reshape_pattern_nonconst = pattern::any_input(); + auto reshape_pattern = std::make_shared(OutputVector{ reshape_pattern_const, reshape_pattern_nonconst }); + auto matcher = pattern::wrap_type({ mul_m, reshape_pattern }); - ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { + ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) { auto op = m.get_match_root(); if (transformation_callback(op)) { return false; } + + // we can propagate only per-tensor dq through reshape with non-const reshape_pattern + const auto& pattern_map = m.get_pattern_value_map(); + if (pattern_map.count(reshape_pattern_nonconst)) { + const auto mul_const = as_type_ptr(pattern_map.at(mul_const_m).get_node_shared_ptr()); + if (!mul_const || ngraph::shape_size(mul_const->get_shape()) != 1) { + return false; + } + } + return transform(*context, m); }; diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp index 8cb86c7e374..2310528bc12 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp @@ -42,7 +42,7 @@ public: }; ngraph::PartialShape inputShape; - std::vector reshapeConstValues; + std::vector reshapeConstValues; // if empty then create shapeOf TestTransformationParams params; Actual actual; Expected expected; @@ -962,6 +962,38 @@ const std::vector testValues = { {} } }, + // U8: non-const reshape pattern and per-tensor dequantization + { + { -1, -1, -1, -1 }, + {}, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + {{ngraph::element::f32}, {128.f}, {0.1f}} + }, + { + ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, {128.f}, {0.1f}} + } + }, + // U8: non-const reshape pattern and per-channel dequantization + { + { -1, 3, -1, -1 }, + {}, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + {{ngraph::element::f32}, {{128.f, 124.f, 120.f}}, {{0.1f, 1.f, 10.f}}} + }, + { + ngraph::element::u8, + {{ngraph::element::f32}, {{128.f, 124.f, 120.f}}, {{0.1f, 1.f, 10.f}}}, + ngraph::element::f32, + {} + } + }, }; INSTANTIATE_TEST_SUITE_P( diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/reshape_function.cpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/reshape_function.cpp index e03dae4b27a..2a794846af8 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/reshape_function.cpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/reshape_function.cpp @@ -20,10 +20,14 @@ std::shared_ptr ReshapeFunction::getOriginal( const std::shared_ptr dequantizationOp = makeDequantization(input, dequantization); - const std::shared_ptr reshape = std::make_shared( - dequantizationOp, - std::make_shared(ngraph::element::i64, ngraph::Shape{ reshapeConstValues.size() }, reshapeConstValues), - true); + std::shared_ptr reshape_pattern; + if (!reshapeConstValues.empty()) { + reshape_pattern = opset1::Constant::create(element::i64, Shape{ reshapeConstValues.size() }, reshapeConstValues); + } else { + reshape_pattern = std::make_shared(dequantizationOp); + } + + const auto reshape = std::make_shared(dequantizationOp, reshape_pattern, true); reshape->set_friendly_name("output"); ngraph::ResultVector results{ std::make_shared(reshape) }; @@ -61,11 +65,15 @@ std::shared_ptr ReshapeFunction::getReference( const std::shared_ptr quantizationOpBefore = makeDequantization(input, dequantizationBefore); - const std::shared_ptr reshapeConstant = std::make_shared( - ngraph::element::i64, - ngraph::Shape{ reshapeConstValues.size() }, - reshapeConstValues); - const std::shared_ptr reshape = std::make_shared(quantizationOpBefore, reshapeConstant, true); + std::shared_ptr reshape_pattern; + if (!reshapeConstValues.empty()) { + reshape_pattern = opset1::Constant::create(element::i64, Shape{ reshapeConstValues.size() }, reshapeConstValues); + } else { + reshape_pattern = makeDequantization(quantizationOpBefore, dequantizationAfter); + reshape_pattern = std::make_shared(reshape_pattern); + } + + const auto reshape = std::make_shared(quantizationOpBefore, reshape_pattern, true); if (quantizationOpBefore->get_output_element_type(0) != precisionAfterOperation) { THROW_IE_LPT_EXCEPTION(*quantizationOpBefore) << "unexpected precision '" << precisionAfterOperation << "' after operation"; }