From 9ce8ac536f4a3df56dec90371070f0b41c4efd95 Mon Sep 17 00:00:00 2001 From: Egor Shulman Date: Tue, 23 Nov 2021 15:55:52 +0300 Subject: [PATCH] [CPU] Optimized shapeInfer() for OneHot (#8739) --- .../nodes/mkldnn_one_hot_node.cpp | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_one_hot_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_one_hot_node.cpp index fdc5d95e662..dbd6f0fafc6 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_one_hot_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_one_hot_node.cpp @@ -84,25 +84,12 @@ bool MKLDNNOneHotNode::needShapeInfer() const { } std::vector MKLDNNOneHotNode::shapeInfer() const { - std::vector input_shapes = { - getParentEdgesAtPort(0)[0]->getMemory().GetShape().getStaticDims(), - getParentEdgesAtPort(1)[0]->getMemory().GetShape().getStaticDims(), - getParentEdgesAtPort(2)[0]->getMemory().GetShape().getStaticDims(), - getParentEdgesAtPort(3)[0]->getMemory().GetShape().getStaticDims() - }; - std::map> input_values = { - {1, std::make_shared(ngraph::element::Type_t::i32, VectorDims{ }, getParentEdgesAtPort(1)[0]->getMemory().GetPtr())}, - {2, std::make_shared(opToShapeInfer->get_input_node_shared_ptr(2))}, - {3, std::make_shared(opToShapeInfer->get_input_node_shared_ptr(3))} - }; - std::vector output_shapes = {{}}; - shape_inference(opToShapeInfer.get(), input_shapes, output_shapes, input_values); - - std::vector result(output_shapes.size()); - std::transform(output_shapes.begin(), output_shapes.end(), result.begin(), [](const ov::StaticShape& s){ return s.to_shape(); }); - depth = reinterpret_cast(getParentEdgesAtPort(1)[0]->getMemoryPtr()->GetPtr())[0]; - return result; + + auto result = getParentEdgesAtPort(0)[0]->getMemory().getStaticDims(); + result.insert(result.begin() + axis, depth); + + return { result }; } void MKLDNNOneHotNode::initSupportedPrimitiveDescriptors() {