[CPU] optimize shape infer of Reshape (#16537)
* add reshape shape infer in cpu plugin
* add squeeze and unsqueeze
* add precision i8 i64 on test
* fix code out-of-bounds risk
* test performance of this PR
* fix code issue
* Revert "test performance of this PR" (reverts commit f4f9f002de28d03bc1c55c24067f75b74824904c)
* fix reviewer comments: fix throw message, do not create an ov::Shape instance, remove the i8 test case
* fix PyTorch layer test failure: inputShape (1, 0) with outPattern (-1) is a valid input
* fix Windows compile issue
* fix rebase mistake

Signed-off-by: Hu Yuan2 <yuan2.hu@intel.com>
This commit is contained in:
parent 05ab0f32d7 · commit 75c62ea320
@@ -3,6 +3,7 @@
//
#include "reshape.h"
#include "utils.hpp"
#include <string>
#include <dnnl_types.h>
#include <dnnl_extension_utils.h>
@@ -10,6 +11,7 @@
#include <ie_ngraph_utils.hpp>
#include <utils/shape_inference/static_shape.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include "utils/shape_inference/shape_inference_cpu.hpp"

#include "common/cpu_memcpy.h"
@@ -34,8 +36,193 @@ bool Reshape::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op
    return true;
}

namespace {
class ReshapeShapeInfer : public ShapeInferEmptyPads {
public:
    ReshapeShapeInfer(bool specialZero) : m_specialZero(specialZero) {}
    Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
                 const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
        static constexpr size_t RESHAPE_SRC = 0, RESHAPE_PATTERN = 1;
        const auto& inputShape = input_shapes[RESHAPE_SRC].get();
        const size_t inputShapeSize = inputShape.size();
        const auto memPtr = data_dependency.at(RESHAPE_PATTERN);
        const auto data = memPtr->GetPtr();
        // const auto outputPatternSize = shape_size(ov::Shape(memPtr->getStaticDims()));
        const auto& dims = memPtr->getStaticDims();
        const auto outputPatternSize = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<Dim>());
        std::vector<int64_t> outPattern = ov::get_raw_data_as<int64_t>(
            InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
            data,
            outputPatternSize,
            ov::util::Cast<int64_t>());
        VectorDims outputShape(outputPatternSize);
        size_t outputProduct(1);
        int32_t minusOneIdx = -1;
        int32_t minusOneCount = 0;
        for (size_t i = 0; i < outputPatternSize; ++i) {
            if (outPattern[i] == 0 && m_specialZero && i < inputShapeSize) {
                outputShape[i] = inputShape[i];
                outputProduct *= outputShape[i];
            } else if (outPattern[i] == -1) {
                minusOneIdx = i;
                minusOneCount++;
            } else {
                outputShape[i] = outPattern[i];
                outputProduct *= outputShape[i];
            }
        }
        size_t inputProduct(1);
        for (size_t i = 0; i < inputShapeSize; ++i) {
            inputProduct *= inputShape[i];
        }
        if (outputProduct != 0 && minusOneIdx >= 0) {
            outputShape[minusOneIdx] = inputProduct / outputProduct;
            outputProduct *= outputShape[minusOneIdx];
        }
        if (minusOneCount > 1 || inputProduct != outputProduct) {
            IE_THROW(Unexpected) << "[cpu]reshape: the shape of input data conflicts with the reshape pattern";
        }
        return {{std::move(outputShape)}, ShapeInferStatus::success};
    }
    port_mask_t get_port_mask() const override {
        return PortMask(1);
    }

private:
    bool m_specialZero;
};

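As a quick illustration of the semantics this class implements: a 0 in the pattern (with special_zero set) copies the corresponding input dimension, a single -1 is inferred from the remaining element count, and the total element counts must match. Below is a minimal standalone sketch of the same rules, for illustration only; it is not the plugin code, assumes plain std::vector<size_t> dims, and omits the multiple-minus-one check that the class performs.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <stdexcept>
    #include <vector>

    using VectorDims = std::vector<size_t>;  // same alias the CPU plugin uses for static dims

    // Standalone restatement of the ReshapeShapeInfer logic above (illustration only).
    VectorDims reshapeInfer(const VectorDims& in, const std::vector<int64_t>& pattern, bool specialZero) {
        VectorDims out(pattern.size(), 0);
        size_t outProduct = 1;
        int64_t minusOneIdx = -1;
        for (size_t i = 0; i < pattern.size(); ++i) {
            if (pattern[i] == 0 && specialZero && i < in.size()) {
                out[i] = in[i];  // special zero: copy the input dimension
            } else if (pattern[i] == -1) {
                minusOneIdx = static_cast<int64_t>(i);  // inferred after the loop
                continue;
            } else {
                out[i] = static_cast<size_t>(pattern[i]);
            }
            outProduct *= out[i];
        }
        const size_t inProduct = std::accumulate(in.begin(), in.end(), size_t{1}, std::multiplies<size_t>());
        if (minusOneIdx >= 0 && outProduct != 0) {
            out[minusOneIdx] = inProduct / outProduct;  // -1 absorbs the remaining elements
            outProduct *= out[minusOneIdx];
        }
        if (inProduct != outProduct) {
            throw std::runtime_error("reshape pattern conflicts with the input shape");
        }
        return out;
    }

    int main() {
        // {2, 3, 4} with pattern {0, -1}: the 0 copies dim 2, the -1 becomes 3 * 4 = 12.
        for (size_t d : reshapeInfer({2, 3, 4}, {0, -1}, /*specialZero=*/true)) {
            std::cout << d << ' ';  // prints: 2 12
        }
        std::cout << '\n';
        return 0;
    }

The commit log notes that inputShape (1, 0) with pattern (-1) must be accepted: in the logic above the -1 resolves to an output dimension of 0, and the final product check passes because both products are 0.
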
class SqueezeShapeInfer : public ShapeInferEmptyPads {
public:
    SqueezeShapeInfer() {}
    Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
                 const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
        static constexpr size_t SQUEEZE_SRC = 0, SQUEEZE_PATTERN = 1;
        const auto& inputShape = input_shapes[SQUEEZE_SRC].get();
        const size_t inputShapeSize = inputShape.size();
        auto itr = data_dependency.find(SQUEEZE_PATTERN);
        VectorDims outputShape;
        if (itr != data_dependency.end()) {
            const auto memPtr = data_dependency.at(SQUEEZE_PATTERN);
            const auto data = memPtr->GetPtr();
            const auto& dims = memPtr->getStaticDims();
            const auto outputPatternSize = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<Dim>());
            std::vector<int64_t> outPattern = ov::get_raw_data_as<int64_t>(
                InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
                data,
                outputPatternSize,
                ov::util::Cast<int64_t>());
            std::vector<bool> removeMask(inputShapeSize, false);
            bool existError = false;
            for (size_t i = 0; i < outputPatternSize; i++) {
                if (outPattern[i] < 0) {
                    outPattern[i] = inputShapeSize + outPattern[i];
                }
                if (outPattern[i] >= 0 && outPattern[i] < static_cast<int64_t>(inputShapeSize)) {
                    removeMask[outPattern[i]] = true;
                } else {
                    existError = true;
                    break;
                }
            }
            for (size_t i = 0; i < inputShapeSize; i++) {
                if (!removeMask[i]) {
                    outputShape.push_back(inputShape[i]);
                } else if (inputShape[i] != 1) {
                    existError = true;
                    break;
                }
            }
            if (existError) {
                IE_THROW(Unexpected) << "[cpu]squeeze: the shape of input data conflicts with the squeeze pattern";
            }
        } else {
            for (size_t i = 0; i < inputShapeSize; i++) {
                if (inputShape[i] != 1) {
                    outputShape.push_back(inputShape[i]);
                }
            }
        }
        return {{std::move(outputShape)}, ShapeInferStatus::success};
    }
    port_mask_t get_port_mask() const override {
        return PortMask(1);
    }
};

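The negative-axis normalization above follows the usual convention: an axis a < 0 refers to rank + a, and a squeezed axis must hold a dimension of 1. When no axes input is available at all (the else branch), every dimension equal to 1 is dropped. A standalone sketch of the same rule, under the same "illustration only" caveat as before, using the {2, 5, 1, 7, 3, 1} shape that also appears in the tests below:

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Standalone restatement of the squeeze rule above (illustration only).
    std::vector<size_t> squeezeInfer(const std::vector<size_t>& in, std::vector<int64_t> axes) {
        const int64_t rank = static_cast<int64_t>(in.size());
        std::vector<bool> remove(in.size(), false);
        for (auto& a : axes) {
            if (a < 0) a += rank;                                     // normalize negative axes
            if (a < 0 || a >= rank) throw std::runtime_error("axis out of range");
            remove[a] = true;
        }
        std::vector<size_t> out;
        for (size_t i = 0; i < in.size(); ++i) {
            if (!remove[i]) out.push_back(in[i]);
            else if (in[i] != 1) throw std::runtime_error("squeezed dim is not 1");
        }
        return out;
    }

    int main() {
        // Shape {2, 5, 1, 7, 3, 1} with axes {2, -1}: -1 normalizes to 5,
        // both axes hold dim 1, so the result is {2, 5, 7, 3}.
        for (size_t d : squeezeInfer({2, 5, 1, 7, 3, 1}, {2, -1})) {
            std::cout << d << ' ';
        }
        std::cout << '\n';
        return 0;
    }
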
class UnsqueezeShapeInfer : public ShapeInferEmptyPads {
public:
    UnsqueezeShapeInfer() {}
    Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
                 const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
        static constexpr size_t UNSQUEEZE_SRC = 0, UNSQUEEZE_PATTERN = 1;
        const auto& inputShape = input_shapes[UNSQUEEZE_SRC].get();
        const size_t inputShapeSize = inputShape.size();
        const auto memPtr = data_dependency.at(UNSQUEEZE_PATTERN);
        const auto data = memPtr->GetPtr();
        const auto& dims = memPtr->getStaticDims();
        const auto outputPatternSize = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<Dim>());
        std::vector<int64_t> outPattern = ov::get_raw_data_as<int64_t>(
            InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
            data,
            outputPatternSize,
            ov::util::Cast<int64_t>());
        size_t outputShapeSize = inputShapeSize + outputPatternSize;
        VectorDims outputShape(outputShapeSize, 0);
        bool existError = false;
        for (size_t i = 0; i < outputPatternSize; i++) {
            if (outPattern[i] < 0) {
                outPattern[i] = outputShapeSize + outPattern[i];
            }
            if (outPattern[i] >= 0 && outPattern[i] < static_cast<int64_t>(outputShapeSize)) {
                outputShape[outPattern[i]] = 1;
            } else {
                existError = true;
                break;
            }
        }
        for (size_t i = 0, y = 0; i < outputShapeSize; i++) {
            if (outputShape[i] == 0) {
                if (y < inputShapeSize) {
                    outputShape[i] = inputShape[y];
                    y++;
                } else {
                    existError = true;
                    break;
                }
            }
        }
        if (existError) {
            IE_THROW(Unexpected) << "[cpu]unsqueeze: the shape of input data conflicts with the unsqueeze pattern";
        }
        return {{std::move(outputShape)}, ShapeInferStatus::success};
    }
    port_mask_t get_port_mask() const override {
        return PortMask(1);
    }
};

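Note that for unsqueeze, negative axes are normalized against the output rank (input rank plus number of axes), not the input rank. A matching standalone sketch (illustration only; the duplicate-axis check is an addition for clarity, the class above instead fails in the fill loop when slots run out):

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Standalone restatement of the unsqueeze rule above (illustration only).
    std::vector<size_t> unsqueezeInfer(const std::vector<size_t>& in, std::vector<int64_t> axes) {
        const int64_t outRank = static_cast<int64_t>(in.size() + axes.size());
        std::vector<size_t> out(outRank, 0);
        for (auto& a : axes) {
            if (a < 0) a += outRank;                         // negative axes index the output rank
            if (a < 0 || a >= outRank || out[a] == 1) throw std::runtime_error("bad or duplicate axis");
            out[a] = 1;                                      // new singleton dimension
        }
        for (size_t i = 0, j = 0; i < out.size(); ++i) {
            if (out[i] == 0) out[i] = in[j++];               // remaining slots take input dims in order
        }
        return out;
    }

    int main() {
        // Shape {2, 5, 7, 3} with axes {2, 5}: result {2, 5, 1, 7, 3, 1},
        // the same pairing used by the unsqueeze tests below.
        for (size_t d : unsqueezeInfer({2, 5, 7, 3}, {2, 5})) {
            std::cout << d << ' ';
        }
        std::cout << '\n';
        return 0;
    }
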
class ReshapeShapeInferFactory : public ShapeInferFactory {
public:
    ReshapeShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
    ShapeInferPtr makeShapeInfer() const override {
        if (const auto reshapeOp = ov::as_type_ptr<const ov::op::v1::Reshape>(m_op)) {
            return std::make_shared<ReshapeShapeInfer>(reshapeOp->get_special_zero());
        } else if (ov::is_type<ov::op::v0::Squeeze>(m_op)) {
            return std::make_shared<SqueezeShapeInfer>();
        } else if (ov::is_type<ov::op::v0::Unsqueeze>(m_op)) {
            return std::make_shared<UnsqueezeShapeInfer>();
        } else {
            IE_THROW(Unexpected) << "[cpu]reshape: " << m_op->get_type_name() << " is not implemented";
        }
    }

private:
    std::shared_ptr<ov::Node> m_op;
};
} // namespace

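For context on the classes above: get_port_mask() returning PortMask(1) marks input port 1 (the pattern/axes tensor) as a data dependency, so its memory is expected to be present in data_dependency when infer() runs. That is why ReshapeShapeInfer and UnsqueezeShapeInfer call data_dependency.at(1) unconditionally, while SqueezeShapeInfer still checks for the entry, since Squeeze's axes input is optional. A minimal sketch of the port-mask idea (an assumption for illustration; the plugin's actual helper lives in shape_inference_cpu.hpp and may differ in detail):

    #include <cstdint>

    // Fold each listed port index into a bitmask of data-dependent input ports.
    // (Hypothetical name; sketch of the idea only.)
    template <typename... Ports>
    constexpr uint32_t portMaskSketch(Ports... ports) {
        return ((1u << ports) | ... | 0u);  // C++17 fold expression
    }

    static_assert(portMaskSketch(1) == 0b10, "port 1 sets bit 1");

Swapping NgraphShapeInferFactory(op, PortMask(1)) for this specialized factory in the constructor below is the optimization the PR title refers to: the output dims are computed directly in the plugin instead of going through the generic ngraph shape-inference path.
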
Reshape::Reshape(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
-    Node(op, context, NgraphShapeInferFactory(op, PortMask(1))) {
+    Node(op, context, ReshapeShapeInferFactory(op)) {
    std::string errorMessage;
    if (!isSupportedOperation(op, errorMessage)) {
        IE_THROW(NotImplemented) << errorMessage;
@@ -45,6 +45,7 @@ using shapeOpsParams = std::tuple<
    ngraph::helpers::InputLayerType,  // second input type
    shapeNodeType,                    // node type
    Precision,                        // precision
    ngraph::element::Type_t,          // second input precision
    bool>;                            // special zero

class ShapeOpsCPUTest : public testing::WithParamInterface<shapeOpsParams>, virtual public SubgraphBaseTest, public CPUTestsBase {
@@ -55,7 +56,8 @@ public:
    shapeNodeType nodeType;
    Precision prc;
    bool specialZero;
-   std::tie(inpDesc, secondType, nodeType, prc, specialZero) = obj.param;
+   element::Type_t tmpSecondInPrc;
+   std::tie(inpDesc, secondType, nodeType, prc, tmpSecondInPrc, specialZero) = obj.param;

    std::ostringstream result;
    result << nodeType << "_";
@@ -72,6 +74,7 @@
    }
    result << "PRC=" << prc << "_";
    result << "specialZero=" << specialZero;
+   result << "_secondInPrc=" << tmpSecondInPrc;

    return result.str();
}
@@ -84,10 +87,21 @@ protected:
        const auto& funcInput = funcInputs[i];
        ov::runtime::Tensor tensor;
        if (i == 1) {
-           tensor = ov::runtime::Tensor{ov::element::i32, targetInputStaticShapes[i]};
-           auto inputData = tensor.data<ov::element_type_traits<ov::element::i32>::value_type>();
-           for (size_t j = 0lu; j < data[idx].size(); ++j) {
-               inputData[j] = data[idx][j];
-           }
+#define RESHAPE_TEST_CASE(INT_TYPE)                                                                 \
+    case ov::element::Type_t::INT_TYPE: {                                                           \
+        tensor = ov::runtime::Tensor{ov::element::INT_TYPE, targetInputStaticShapes[i]};            \
+        auto inputData = tensor.data<ov::element_type_traits<ov::element::INT_TYPE>::value_type>(); \
+        for (size_t j = 0lu; j < data[idx].size(); ++j) {                                           \
+            inputData[j] = data[idx][j];                                                            \
+        }                                                                                           \
+        break;                                                                                      \
+    }
+           switch (secondInPrc) {
+               RESHAPE_TEST_CASE(i64)
+               RESHAPE_TEST_CASE(i32)
+               default:
+                   FAIL() << "We shouldn't get here.";
+#undef RESHAPE_TEST_CASE
+           }
        } else {
            if (funcInput.get_element_type().is_real()) {
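Since the macro above packs a whole switch case into one definition per second-input precision, RESHAPE_TEST_CASE(i64) expands to roughly the following (formatting adjusted for readability):

    case ov::element::Type_t::i64: {
        tensor = ov::runtime::Tensor{ov::element::i64, targetInputStaticShapes[i]};
        auto inputData = tensor.data<ov::element_type_traits<ov::element::i64>::value_type>();
        for (size_t j = 0lu; j < data[idx].size(); ++j) {
            inputData[j] = data[idx][j];
        }
        break;
    }
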
@@ -110,7 +124,7 @@ protected:
    shapeNodeType nodeType;
    Precision prc;
    bool specialZero;
-   std::tie(inpDesc, secondType, nodeType, prc, specialZero) = this->GetParam();
+   std::tie(inpDesc, secondType, nodeType, prc, secondInPrc, specialZero) = this->GetParam();

    selectedType = std::string("unknown_") + prc.name();

@@ -123,7 +137,6 @@ protected:
    init_input_shapes(inputShapes);

    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(prc);
-   const auto secondInPrc = ngraph::element::Type_t::i32;
    auto inputs = ngraph::builder::makeDynamicParams(ngPrc, {inputDynamicShapes.front()});
    auto dataInput = inputs.front();
    dataInput->set_friendly_name("param_1");
@@ -158,6 +171,7 @@ protected:
private:
    std::vector<std::vector<int>> data;
    size_t idx;
+   element::Type_t secondInPrc;
};

TEST_P(ShapeOpsCPUTest, CompareWithRefs) {
@@ -166,6 +180,7 @@ TEST_P(ShapeOpsCPUTest, CompareWithRefs) {
}

namespace reshapeTest {
+const std::vector<ov::element::Type_t> secondInPrcs{ov::element::Type_t::i64, ov::element::Type_t::i32};

inputDescription noBounds{{{-1, -1, -1, -1},
                           {ngraph::Shape{2, 5, 7, 3}, ngraph::Shape{10, 6, 10, 5}, ngraph::Shape{10, 6, 10, 5}, ngraph::Shape{1, 2, 5, 5}}},
@@ -175,6 +190,7 @@ const auto params = ::testing::Combine(::testing::Values(noBounds),
                                       ::testing::Values(ngraph::helpers::InputLayerType::PARAMETER),
                                       ::testing::Values(shapeNodeType::Reshape),
                                       ::testing::Values(Precision::FP32),
+                                      ::testing::ValuesIn(secondInPrcs),
                                       ::testing::Values(true));

INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic, ShapeOpsCPUTest, params, ShapeOpsCPUTest::getTestCaseName);
@@ -187,6 +203,7 @@ const auto params_const = ::testing::Combine(::testing::Values(noBounds_const),
                                             ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
                                             ::testing::Values(shapeNodeType::Reshape),
                                             ::testing::Values(Precision::FP32),
+                                            ::testing::ValuesIn(secondInPrcs),
                                             ::testing::Values(true));

INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic_const, ShapeOpsCPUTest, params_const, ShapeOpsCPUTest::getTestCaseName);
@@ -199,6 +216,7 @@ const auto params_dynBatch = ::testing::Combine(::testing::Values(shape_dynBatch
                                                ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
                                                ::testing::Values(shapeNodeType::Reshape),
                                                ::testing::Values(Precision::FP32),
+                                               ::testing::ValuesIn(secondInPrcs),
                                                ::testing::Values(true));

INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynBatch, ShapeOpsCPUTest, params_dynBatch, ShapeOpsCPUTest::getTestCaseName);
@@ -206,7 +224,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynBatch, ShapeOpsCPUTest, params
} // namespace reshapeTest

namespace squeezeTest {

+const std::vector<ov::element::Type_t> secondInPrcs{ov::element::Type_t::i64, ov::element::Type_t::i32};
inputDescription noBounds{{{-1, -1, -1, -1, -1, -1},
                           {
                               ngraph::Shape{2, 5, 1, 7, 3, 1},
@@ -220,6 +238,7 @@ const auto params = ::testing::Combine(::testing::Values(noBounds),
                                       ::testing::Values(ngraph::helpers::InputLayerType::PARAMETER),
                                       ::testing::Values(shapeNodeType::Squeeze),
                                       ::testing::Values(Precision::FP32),
+                                      ::testing::ValuesIn(secondInPrcs),
                                       ::testing::Values(true));

// at the moment, Squeeze produces a dynamic output rank if the second input is not constant
@@ -234,6 +253,7 @@ const auto params_const = ::testing::Combine(::testing::Values(noBounds_const),
                                             ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
                                             ::testing::Values(shapeNodeType::Squeeze),
                                             ::testing::Values(Precision::FP32),
+                                            ::testing::ValuesIn(secondInPrcs),
                                             ::testing::Values(true));

INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic_const, ShapeOpsCPUTest, params_const, ShapeOpsCPUTest::getTestCaseName);
@@ -241,7 +261,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic_const, ShapeOpsCPUTest, p
} // namespace squeezeTest

namespace unsqueezeTest {

+const std::vector<ov::element::Type_t> secondInPrcs{ov::element::Type_t::i64, ov::element::Type_t::i32};
inputDescription noBounds{{{-1, -1, -1, -1},
                           {ngraph::Shape{2, 5, 7, 3}, ngraph::Shape{10, 6, 10, 5}, ngraph::Shape{10, 6, 10, 5}, ngraph::Shape{5, 1, 5}}},
                          {std::vector<int>{2, 5}, std::vector<int>{1, 2}, std::vector<int>{4, 5}, std::vector<int>{0, 1}}};
@@ -250,6 +270,7 @@ const auto params = ::testing::Combine(::testing::Values(noBounds),
                                       ::testing::Values(ngraph::helpers::InputLayerType::PARAMETER),
                                       ::testing::Values(shapeNodeType::Unsqueeze),
                                       ::testing::Values(Precision::FP32),
+                                      ::testing::ValuesIn(secondInPrcs),
                                       ::testing::Values(true));

// at the moment, Unsqueeze produces a dynamic output rank if the second input is not constant
@@ -264,6 +285,7 @@ const auto params_const = ::testing::Combine(::testing::Values(noBounds_const),
                                             ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
                                             ::testing::Values(shapeNodeType::Unsqueeze),
                                             ::testing::Values(Precision::FP32),
+                                            ::testing::ValuesIn(secondInPrcs),
                                             ::testing::Values(true));

INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic_const, ShapeOpsCPUTest, params_const, ShapeOpsCPUTest::getTestCaseName);