diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_select_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_select_node.cpp
index dc6001ad74a..ba9140385c1 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_select_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_select_node.cpp
@@ -17,10 +17,6 @@ using namespace InferenceEngine;
 
 bool MKLDNNSelectNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
     try {
-        if (isDynamicNgraphNode(op)) {
-            errorMessage = "Doesn't support op with dynamic shapes";
-            return false;
-        }
         const auto select = std::dynamic_pointer_cast<const ngraph::opset1::Select>(op);
         if (!select) {
             errorMessage = "Only opset1 Select operation is supported";
@@ -47,7 +43,7 @@ MKLDNNSelectNode::MKLDNNSelectNode(const std::shared_ptr<ngraph::Node>& op, cons
     errorPrefix = "Select layer with name '" + op->get_friendly_name() + "'";
     const auto select = std::dynamic_pointer_cast<const ngraph::opset1::Select>(op);
 
-    if (op->get_input_size() != numOfInputs || op->get_output_size() != 1)
+    if (inputShapes.size() != numOfInputs || outputShapes.size() != 1)
         IE_THROW() << errorPrefix << " has incorrect number of input/output edges!";
 
     const auto broadcast = select->get_auto_broadcast();
@@ -59,56 +55,43 @@ MKLDNNSelectNode::MKLDNNSelectNode(const std::shared_ptr<ngraph::Node>& op, cons
         IE_THROW() << errorPrefix << " has unsupported broadcast type: " + ngraph::as_string(broadcast.m_type);
     }
 
-    auto conditionShapes = op->get_input_shape(CONDITION);
-    if (ngraph::is_scalar(conditionShapes))
-        conditionShapes = ngraph::Shape{1};
-    auto thenShapes = op->get_input_shape(THEN);
-    if (ngraph::is_scalar(thenShapes))
-        thenShapes = ngraph::Shape{1};
-    auto elseShapes = op->get_input_shape(ELSE);
-    if (ngraph::is_scalar(elseShapes))
-        elseShapes = ngraph::Shape{1};
-    auto outputShapes = op->get_output_shape(0);
-    if (ngraph::is_scalar(outputShapes))
-        outputShapes = ngraph::Shape{1};
+    const auto &conditionShape = getInputShapeAtPort(CONDITION).getDims();
+    const auto &thenShape = getInputShapeAtPort(THEN).getDims();
+    const auto &elseShape = getInputShapeAtPort(ELSE).getDims();
+    const auto &outputShape = getOutputShapeAtPort(0).getDims();
 
-    if (broadcastType == SelectBroadcastType::NONE && ((conditionShapes != outputShapes) || (thenShapes != outputShapes) ||
-        (elseShapes != outputShapes)))
+    if (broadcastType == SelectBroadcastType::NONE && (!dimsEqualWeak(conditionShape, outputShape) || !dimsEqualWeak(thenShape, outputShape) ||
+                                                       !dimsEqualWeak(elseShape, outputShape))) {
         IE_THROW() << errorPrefix << " and auto_broadcast='none' has input shapes mismatch";
+    }
 
     if (broadcastType == SelectBroadcastType::NUMPY) {
-        if (outputShapes.size() < conditionShapes.size() || outputShapes.size() < thenShapes.size() || outputShapes.size() < elseShapes.size())
+        if (outputShape.size() < conditionShape.size() || outputShape.size() < thenShape.size() || outputShape.size() < elseShape.size())
             IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible input and output shapes";
 
-        for (int condIt = conditionShapes.size() - 1, outIt = outputShapes.size() - 1; condIt >= 0; condIt--, outIt--)
-            if (conditionShapes[condIt] != outputShapes[outIt] && conditionShapes[condIt] != 1)
+        for (int condIt = conditionShape.size() - 1, outIt = outputShape.size() - 1; condIt >= 0; condIt--, outIt--)
+            if (!dimsEqualWeak(conditionShape[condIt], outputShape[outIt]) && !dimsEqualWeak(conditionShape[condIt], 1))
                 IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Condition' input and output shapes";
 
-        for (int thenIt = thenShapes.size() - 1, outIt = outputShapes.size() - 1; thenIt >= 0; thenIt--, outIt--)
-            if (thenShapes[thenIt] != outputShapes[outIt] && thenShapes[thenIt] != 1)
+        for (int thenIt = thenShape.size() - 1, outIt = outputShape.size() - 1; thenIt >= 0; thenIt--, outIt--)
+            if (!dimsEqualWeak(thenShape[thenIt], outputShape[outIt]) && !dimsEqualWeak(thenShape[thenIt], 1))
                 IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Then' input and output shapes";
 
-        for (int elseIt = elseShapes.size() - 1, outIt = outputShapes.size() - 1; elseIt >= 0; elseIt--, outIt--)
-            if (elseShapes[elseIt] != outputShapes[outIt] && elseShapes[elseIt] != 1)
+        for (int elseIt = elseShape.size() - 1, outIt = outputShape.size() - 1; elseIt >= 0; elseIt--, outIt--)
+            if (!dimsEqualWeak(elseShape[elseIt], outputShape[outIt]) && !dimsEqualWeak(elseShape[elseIt], 1))
                 IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Else' input and output shapes";
     }
 
     resDims.resize(numOfDims, 1);
-    std::copy(std::begin(outputShapes), std::end(outputShapes), std::begin(resDims) + (numOfDims - outputShapes.size()));
     if (broadcastType == SelectBroadcastType::NUMPY) {
-        calcOutOffset(resOffset, resDims);
+        resOffset.resize(numOfDims);
+        condOffset.resize(numOfDims);
+        thenOffset.resize(numOfDims);
+        elseOffset.resize(numOfDims);
 
-        std::vector<size_t> condDims(numOfDims, 1);
-        std::copy(std::begin(conditionShapes), std::end(conditionShapes), std::begin(condDims) + (numOfDims - conditionShapes.size()));
-        calcInOffset(condOffset, condDims, resDims);
-
-        std::vector<size_t> thenDims(numOfDims, 1);
-        std::copy(std::begin(thenShapes), std::end(thenShapes), std::begin(thenDims) + (numOfDims - thenShapes.size()));
-        calcInOffset(thenOffset, thenDims, resDims);
-
-        std::vector<size_t> elseDims(numOfDims, 1);
-        std::copy(std::begin(elseShapes), std::end(elseShapes), std::begin(elseDims) + (numOfDims - elseShapes.size()));
-        calcInOffset(elseOffset, elseDims, resDims);
+        condDims.resize(numOfDims, 1);
+        thenDims.resize(numOfDims, 1);
+        elseDims.resize(numOfDims, 1);
     }
 }
@@ -140,8 +123,47 @@ void MKLDNNSelectNode::initSupportedPrimitiveDescriptors() {
         impl_desc_type::ref_any);
 }
 
-void MKLDNNSelectNode::calcOutOffset(std::vector<size_t>& offset, const std::vector<size_t>& dims) {
-    offset.resize(numOfDims);
+void MKLDNNSelectNode::prepareParams() {
+    if (!inputShapesDefined()) {
+        IE_THROW() << "Can't prepare params for select node with name: " << getName();
+    }
+
+    const auto &_conditionDims = getParentEdgesAtPort(CONDITION)[0]->getMemory().getStaticDims();
+    const auto &_thenDims = getParentEdgesAtPort(THEN)[0]->getMemory().getStaticDims();
+    const auto &_elseDims = getParentEdgesAtPort(ELSE)[0]->getMemory().getStaticDims();
+    const auto &_outputDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
+
+    std::fill(resDims.begin(), resDims.end(), 1);
+    std::copy(std::begin(_outputDims), std::end(_outputDims), std::begin(resDims) + (numOfDims - _outputDims.size()));
+    if (broadcastType == SelectBroadcastType::NUMPY) {
+        std::fill(resOffset.begin(), resOffset.end(), 1);
+        calcOutOffset(resOffset, resDims);
+
+        std::fill(condDims.begin(), condDims.end(), 1);
+        std::copy(std::begin(_conditionDims), std::end(_conditionDims), std::begin(condDims) + (numOfDims - _conditionDims.size()));
+        std::fill(condOffset.begin(), condOffset.end(), 1);
+        calcInOffset(condOffset, condDims, resDims);
+
+        std::fill(thenDims.begin(), thenDims.end(), 1);
+        std::copy(std::begin(_thenDims), std::end(_thenDims), std::begin(thenDims) + (numOfDims - _thenDims.size()));
+        std::fill(thenOffset.begin(), thenOffset.end(), 1);
+        calcInOffset(thenOffset, thenDims, resDims);
+
+        std::fill(elseDims.begin(), elseDims.end(), 1);
+        std::copy(std::begin(_elseDims), std::end(_elseDims), std::begin(elseDims) + (numOfDims - _elseDims.size()));
+        std::fill(elseOffset.begin(), elseOffset.end(), 1);
+        calcInOffset(elseOffset, elseDims, resDims);
+    }
+}
+
+void MKLDNNSelectNode::createPrimitive() {
+    if (inputShapesDefined()) {
+        prepareParams();
+        updateLastInputDims();
+    }
+}
+
+void MKLDNNSelectNode::calcOutOffset(VectorDims& offset, const VectorDims& dims) {
     int k = 1;
     for (int i = dims.size() - 1; i >= 0; i--) {
         offset[i] = k;
@@ -149,8 +171,7 @@ void MKLDNNSelectNode::calcOutOffset(std::vector<size_t>& offset, const std::vec
     }
 }
 
-void MKLDNNSelectNode::calcInOffset(std::vector<size_t>& offset, const std::vector<size_t>& inDims, const std::vector<size_t>& outDims) {
-    offset.resize(numOfDims);
+void MKLDNNSelectNode::calcInOffset(VectorDims& offset, const VectorDims& inDims, const VectorDims& outDims) {
    int k = 1;
    for (int i = inDims.size() - 1; i >= 0; i--) {
        offset[i] = (inDims[i] == outDims[i]) ? k : 0;
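The substance of the .cpp change: the rank-padded dims and the offset (stride) tables move out of the constructor, where static shapes were required, into prepareParams(), which reruns whenever input shapes change. For readers who want the arithmetic in isolation, here is a minimal standalone sketch of what calcOutOffset and calcInOffset compute, plus the weak dimension comparison the validation now relies on. All names are local to this example (UNDEF stands in for the plugin's undefined-dimension sentinel); only the arithmetic mirrors the patch.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

using Dims = std::vector<size_t>;
constexpr size_t UNDEF = static_cast<size_t>(-1);

// Weak equality: an undefined (dynamic) dimension matches anything, so the
// constructor can validate what is known and defer the rest to prepareParams().
bool dimsEqualWeakSketch(size_t lhs, size_t rhs) {
    return lhs == UNDEF || rhs == UNDEF || lhs == rhs;
}

// Row-major strides of the output: offset[i] = product of dims[i+1..N-1].
Dims outStrides(const Dims& dims) {
    Dims offset(dims.size());
    size_t k = 1;
    for (int i = static_cast<int>(dims.size()) - 1; i >= 0; i--) {
        offset[i] = k;
        k *= dims[i];
    }
    return offset;
}

// Input strides under NUMPY broadcasting: a size-1 dimension that is being
// broadcast gets stride 0, so every output index reads the same element.
Dims inStrides(const Dims& inDims, const Dims& outDims) {
    Dims offset(inDims.size());
    size_t k = 1;
    for (int i = static_cast<int>(inDims.size()) - 1; i >= 0; i--) {
        offset[i] = (inDims[i] == outDims[i]) ? k : 0;
        k *= inDims[i];
    }
    return offset;
}

int main() {
    // Condition padded to {1, 3}, output {2, 3}: dim 0 broadcasts -> stride 0.
    for (size_t s : inStrides({1, 3}, {2, 3}))
        std::cout << s << ' ';               // prints: 0 1
    std::cout << '\n';
    for (size_t s : outStrides({2, 3}))
        std::cout << s << ' ';               // prints: 3 1
    std::cout << '\n';
    std::cout << std::boolalpha << dimsEqualWeakSketch(UNDEF, 7) << '\n';  // true
}
```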
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_select_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_select_node.h
index f6e84a34de9..6602195f122 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_select_node.h
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_select_node.h
@@ -18,10 +18,13 @@ public:
     void getSupportedDescriptors() override {};
     void initSupportedPrimitiveDescriptors() override;
-    void createPrimitive() override {};
+    void createPrimitive() override;
     void execute(mkldnn::stream strm) override;
     bool created() const override;
 
+    void executeDynamicImpl(mkldnn::stream strm) override { execute(strm); }
+    void prepareParams() override;
+
     static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
 
 private:
@@ -33,16 +36,20 @@ private:
     };
     SelectBroadcastType broadcastType;
 
-    std::vector<size_t> resDims;
-    std::vector<size_t> resOffset;
-    std::vector<size_t> condOffset;
-    std::vector<size_t> thenOffset;
-    std::vector<size_t> elseOffset;
+    VectorDims resDims;
+    VectorDims resOffset;
+    VectorDims condOffset;
+    VectorDims thenOffset;
+    VectorDims elseOffset;
+
+    VectorDims condDims;
+    VectorDims thenDims;
+    VectorDims elseDims;
 
     std::string errorPrefix;
 
-    void calcOutOffset(std::vector<size_t>& offset, const std::vector<size_t>& dims);
-    void calcInOffset(std::vector<size_t>& offset, const std::vector<size_t>& inDims, const std::vector<size_t>& outDims);
+    void calcOutOffset(VectorDims& offset, const VectorDims& dims);
+    void calcInOffset(VectorDims& offset, const VectorDims& inDims, const VectorDims& outDims);
 
     template <typename COND_T, typename DATA_T>
     void execute_impl();
 };
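The header captures the dynamic-shape contract the node now opts into: createPrimitive() becomes a real override, prepareParams() rebuilds the shape-dependent state, and executeDynamicImpl() forwards straight to execute(). Below is a hypothetical, heavily simplified driver showing how these overrides are expected to interact; the real MKLDNNNode base class is far more involved, so treat everything except the override names as an assumption.

```cpp
#include <iostream>

struct SelectNodeSketch {
    bool shapesDefined = false;                  // stands in for inputShapesDefined()

    void prepareParams() { std::cout << "rebuild padded dims and offset tables\n"; }
    void execute() { std::cout << "run reference Select kernel\n"; }
    void executeDynamicImpl() { execute(); }     // the patch forwards 1:1

    // With static shapes, createPrimitive() can prepare everything up front;
    // with dynamic shapes, preparation is deferred to infer time.
    void createPrimitive() {
        if (shapesDefined)
            prepareParams();
    }

    // Assumed per-infer path once concrete shapes arrive (in the real plugin
    // this sequencing is driven by the base class, not by the node itself).
    void infer() {
        shapesDefined = true;
        prepareParams();
        executeDynamicImpl();
    }
};

int main() {
    SelectNodeSketch node;
    node.createPrimitive();  // dynamic case: shapes unknown, nothing to prepare yet
    node.infer();            // shapes known now: prepare, then execute
}
```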
diff --git a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/select.cpp b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/select.cpp
new file mode 100644
index 00000000000..e7e92b78c27
--- /dev/null
+++ b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/select.cpp
@@ -0,0 +1,174 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "test_utils/cpu_test_utils.hpp"
+#include "ngraph_functions/builders.hpp"
+
+using namespace ngraph;
+using namespace InferenceEngine;
+using namespace CPUTestUtils;
+
+namespace CPULayerTestsDefinitions {
+
+using selectParams = std::tuple<
+    std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>,  // input shapes
+    ngraph::op::AutoBroadcastSpec>;  // broadcast
+
+class SelectLayerCPUTest : public testing::WithParamInterface<selectParams>, public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo<selectParams> obj) {
+        std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
+        ngraph::op::AutoBroadcastSpec broadcast;
+        std::tie(shapes, broadcast) = obj.param;
+
+        std::ostringstream result;
+        result << "IS=" << CommonTestUtils::partialShape2str(shapes.first) << "_";
+        result << "TS=";
+        for (const auto& shape : shapes.second) {
+            result << "(";
+            for (const auto& item : shape) {
+                result << CommonTestUtils::vec2str(item) << "_";
+            }
+            result << ")_";
+        }
+        result << "Broadcast=" << broadcast.m_type;
+
+        return result.str();
+    }
+
+protected:
+    void SetUp() override {
+        targetDevice = CommonTestUtils::DEVICE_CPU;
+
+        std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
+        ngraph::op::AutoBroadcastSpec broadcast;
+        std::tie(shapes, broadcast) = this->GetParam();
+
+        for (size_t i = 0; i < shapes.second.size(); i++) {
+            targetStaticShapes.push_back(shapes.second[i]);
+        }
+        inputDynamicShapes = shapes.first;
+
+        selectedType = std::string("ref_any_") + Precision(Precision::I8).name();
+
+        ngraph::ParameterVector paramNodesVector;
+        auto paramNode = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::Type_t::boolean, ngraph::Shape(targetStaticShapes[0][0]));
+        paramNodesVector.push_back(paramNode);
+        auto inType = ngraph::element::Type_t::f32;
+        for (size_t i = 1; i < targetStaticShapes[0].size(); i++) {
+            paramNode = std::make_shared<ngraph::opset1::Parameter>(inType, ngraph::Shape(targetStaticShapes[0][i]));
+            paramNodesVector.push_back(paramNode);
+        }
+        auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(paramNodesVector));
+
+        auto select = ngraph::builder::makeSelect(paramOuts, broadcast);
+
+        function = std::make_shared<ngraph::Function>(select, paramNodesVector, "SelectLayerCPUTest");
+        functionRefs = ngraph::clone_function(*function);
+    }
+};
+
+TEST_P(SelectLayerCPUTest, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+
+    Run();
+    CheckPluginRelatedResults(executableNetwork, "Select");
+}
+
+std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>> inShapesDynamicNumpy = {
+    {
+        // dynamic
+        {
+            {-1, -1, -1, -1},
+            {-1, -1, -1, -1, -1},
+            {-1, -1, -1, -1}
+        },
+
+        // target
+        {
+            {{5, 1, 2, 1}, {8, 1, 9, 1, 1}, {5, 1, 2, 1}},
+            {{1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1}},
+            {{5, 9, 8, 7}, {21, 5, 9, 8, 7}, {1, 1, 1, 1}},
+        }
+    },
+    {
+        // dynamic
+        {
+            {-1, -1},
+            {-1, -1, -1, -1, -1},
+            {-1, -1, -1}
+        },
+
+        // target
+        {
+            {{8, 1}, {2, 1, 1, 8, 1}, {9, 1, 1}},
+            {{10, 5}, {7, 8, 3, 10, 5}, {3, 10, 5}},
+            {{8, 7}, {1, 1, 1, 8, 1}, {1, 1, 7}},
+        }
+    },
+    {
+        // dynamic
+        {
+            {{2, 8}, {3, 7}, {1, 10}, {1, 6}, {1, 10}},
+            {-1, -1, -1, -1, -1},
+            {{1, 5}, {1, 11}, {5, 5}, {1, 8}}
+        },
+
+        // target
+        {
+            {{5, 4, 1, 1, 1}, {5, 1, 8, 1, 1}, {1, 1, 5, 1}},
+            {{8, 5, 5, 5, 1}, {8, 1, 1, 1, 8}, {5, 5, 5, 8}},
+            {{2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {3, 4, 5, 6}},
+        }
+    },
+    {
+        // dynamic
+        {
+            {{1, 10}},
+            {{1, 15}, {2, 7}, {1, 6}, {5, 12}, {1, 20}},
+            {{2, 10}, {1, 16}}
+        },
+
+        // target
+        {
+            {{4}, {8, 5, 6, 6, 1}, {6, 4}},
+            {{10}, {15, 7, 6, 10, 10}, {10, 10}},
+            {{1}, {2, 5, 4, 5, 3}, {5, 1}},
+        }
+    }
+};
+
+const auto numpyCases = ::testing::Combine(
+    ::testing::ValuesIn(inShapesDynamicNumpy),
+    ::testing::Values(ngraph::op::AutoBroadcastSpec::NUMPY)
+);
+
+INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNumpy_dynamic, SelectLayerCPUTest, numpyCases, SelectLayerCPUTest::getTestCaseName);
+
+std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>> inShapesDynamicNone = {
+    {
+        // dynamic
+        {
+            {{1, 10}, -1, {10, 20}, {1, 5}},
+            {-1, {16, 16}, -1, -1},
+            {-1, -1, -1, -1}
+        },
+
+        // target
+        {
+            {{3, 16, 15, 5}, {3, 16, 15, 5}, {3, 16, 15, 5}},
+            {{1, 16, 10, 1}, {1, 16, 10, 1}, {1, 16, 10, 1}},
+            {{10, 16, 20, 5}, {10, 16, 20, 5}, {10, 16, 20, 5}}
+        }
+    }
+};
+
+const auto noneCases = ::testing::Combine(
+    ::testing::ValuesIn(inShapesDynamicNone),
+    ::testing::Values(ngraph::op::AutoBroadcastSpec::NONE)
+);
+
+INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNone_dynamic, SelectLayerCPUTest, noneCases, SelectLayerCPUTest::getTestCaseName);
+
+} // namespace CPULayerTestsDefinitions
\ No newline at end of file
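Each test entry pairs the partial shapes the network is built with ("dynamic") against several static shapes fed at inference time ("target"); the invariant behind these tables is that every target shape must be a valid instance of its partial shape. A small self-contained check using ngraph's public PartialShape::compatible; the shapes are copied from the first NONE case above, everything else is illustrative.

```cpp
#include <cassert>
#include <ngraph/partial_shape.hpp>

int main() {
    // Condition input of the first NONE case: interval and fully dynamic dims.
    ngraph::PartialShape dynamic{{1, 10}, -1, {10, 20}, {1, 5}};
    // One of the target static shapes fed to that input at infer time.
    ngraph::PartialShape target{3, 16, 15, 5};

    // compatible() holds when every dimension admits the static value:
    // 3 in [1,10], 16 matches -1, 15 in [10,20], 5 in [1,5].
    assert(dynamic.compatible(target));
    return 0;
}
```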