[BUG fix] Reshape node: WA in-place failure case by mem-copy (#10828)
* Handle in-place failure cases in reshape node * Disable inplace when non-const reshape connected to constant * Add comment to reshape_inplace test * move copy WA into execute() to cover more general in-place failure cases
This commit is contained in:
parent
a571539107
commit
3f9c6b2f3f
@ -111,6 +111,12 @@ void Reshape::initSupportedPrimitiveDescriptors() {
|
|||||||
if (inPrec != outPrec)
|
if (inPrec != outPrec)
|
||||||
inPrec = outPrec;
|
inPrec = outPrec;
|
||||||
|
|
||||||
|
bool canBeInPlace = true;
|
||||||
|
|
||||||
|
// CVS-81059 : disable inPlace in following case since it won't be satisfied by framework
|
||||||
|
if (!isConstant() && getParentEdgeAt(0)->getParent()->isConstant())
|
||||||
|
canBeInPlace = false;
|
||||||
|
|
||||||
NodeConfig config;
|
NodeConfig config;
|
||||||
config.dynBatchSupport = true;
|
config.dynBatchSupport = true;
|
||||||
config.inConfs.resize(getParentEdges().size());
|
config.inConfs.resize(getParentEdges().size());
|
||||||
@ -121,7 +127,7 @@ void Reshape::initSupportedPrimitiveDescriptors() {
|
|||||||
config.inConfs[i].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc((i > 0 ? secondInPrc : inPrec), getInputShapeAtPort(i)));
|
config.inConfs[i].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc((i > 0 ? secondInPrc : inPrec), getInputShapeAtPort(i)));
|
||||||
}
|
}
|
||||||
config.outConfs.resize(1);
|
config.outConfs.resize(1);
|
||||||
config.outConfs[0].inPlace(0);
|
config.outConfs[0].inPlace(canBeInPlace ? 0 : -1);
|
||||||
config.outConfs[0].constant(false);
|
config.outConfs[0].constant(false);
|
||||||
config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(outPrec, getOutputShapeAtPort(0)));
|
config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(outPrec, getOutputShapeAtPort(0)));
|
||||||
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
|
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
|
||||||
@ -131,6 +137,24 @@ void Reshape::executeDynamicImpl(dnnl::stream strm) {
|
|||||||
execute(strm);
|
execute(strm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Reshape::execute(dnnl::stream strm) {
|
||||||
|
auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
|
||||||
|
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
|
||||||
|
|
||||||
|
auto srcPtr = static_cast<uint8_t*>(srcMemPtr->GetPtr());
|
||||||
|
auto dstPtr = static_cast<uint8_t*>(dstMemPtr->GetPtr());
|
||||||
|
|
||||||
|
if (dstPtr != srcPtr) {
|
||||||
|
cpu_memcpy(dstPtr, srcPtr, dstMemPtr->GetSize());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Reshape::isExecutable() const {
|
||||||
|
bool inPlaceEnabled =
|
||||||
|
getSelectedPrimitiveDescriptor() && getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].inPlace() >= 0;
|
||||||
|
return !inPlaceEnabled;
|
||||||
|
}
|
||||||
|
|
||||||
bool Reshape::created() const {
|
bool Reshape::created() const {
|
||||||
return getType() == Type::Reshape;
|
return getType() == Type::Reshape;
|
||||||
}
|
}
|
||||||
|
@ -22,14 +22,13 @@ public:
|
|||||||
void getSupportedDescriptors() override;
|
void getSupportedDescriptors() override;
|
||||||
void initSupportedPrimitiveDescriptors() override;
|
void initSupportedPrimitiveDescriptors() override;
|
||||||
bool created() const override;
|
bool created() const override;
|
||||||
bool isExecutable() const override {
|
bool isExecutable() const override;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool needShapeInfer() const override;
|
bool needShapeInfer() const override;
|
||||||
std::vector<VectorDims> shapeInfer() const override;
|
std::vector<VectorDims> shapeInfer() const override;
|
||||||
bool needPrepareParams() const override { return false; }
|
bool needPrepareParams() const override { return false; }
|
||||||
void executeDynamicImpl(dnnl::stream strm) override;
|
void executeDynamicImpl(dnnl::stream strm) override;
|
||||||
|
void execute(dnnl::stream strm) override;
|
||||||
|
|
||||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||||
|
|
||||||
|
@ -0,0 +1,90 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <common_test_utils/ov_tensor_utils.hpp>
|
||||||
|
#include "ngraph/runtime/aligned_buffer.hpp"
|
||||||
|
#include "ngraph_functions/builders.hpp"
|
||||||
|
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||||
|
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||||
|
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||||
|
|
||||||
|
using namespace InferenceEngine;
|
||||||
|
using namespace ov::test;
|
||||||
|
namespace SubgraphTestsDefinitions {
|
||||||
|
// Subgraph:
|
||||||
|
/*
|
||||||
|
* params[0] params[1]
|
||||||
|
* | |
|
||||||
|
* constant shapeOf /
|
||||||
|
* \ | /
|
||||||
|
* broadcast /
|
||||||
|
* \ /
|
||||||
|
* \ /
|
||||||
|
* reshape
|
||||||
|
* |
|
||||||
|
* result
|
||||||
|
*
|
||||||
|
* This test is designed for correctness of reshape's in-place implementation.
|
||||||
|
*
|
||||||
|
* Due to non-const target shape parameter (params[1]), reshape node
|
||||||
|
* is non-constant node even though the input tensor is constant node.
|
||||||
|
*
|
||||||
|
* some logic protecting constant data from being corrupted by
|
||||||
|
* the in-place consumer may breaks the in-place assumption, and reshape
|
||||||
|
* should be able to handle this case correctly.
|
||||||
|
*/
|
||||||
|
|
||||||
|
class InPlaceReshapeFromConstantCheck : public SubgraphBaseTest {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
const auto rtPrc = ov::element::f32;
|
||||||
|
const ov::Shape inpShape = {21660, 4};
|
||||||
|
const ov::Shape secShape = {4};
|
||||||
|
ngraph::ParameterVector params(2);
|
||||||
|
targetStaticShapes = {{inpShape, secShape}};
|
||||||
|
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||||
|
params[0] = ngraph::builder::makeParams(rtPrc, {inpShape})[0];
|
||||||
|
params[1] = ngraph::builder::makeParams(ov::element::i32, {secShape})[0];
|
||||||
|
auto shape = std::make_shared<ov::opset8::ShapeOf>(params[0]);
|
||||||
|
auto c = ngraph::builder::makeConstant<float>(rtPrc, {}, {1.0f});
|
||||||
|
auto broadcast = std::make_shared<ov::opset8::Broadcast>(c, shape);
|
||||||
|
auto reshape = std::make_shared<ov::opset8::Reshape>(broadcast, params[1], false);
|
||||||
|
ov::ResultVector results{std::make_shared<ngraph::opset1::Result>(reshape->output(0))};
|
||||||
|
function = std::make_shared<ngraph::Function>(results, params, "reshape_check");
|
||||||
|
}
|
||||||
|
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
|
||||||
|
inputs.clear();
|
||||||
|
const auto& funcInputs = function->inputs();
|
||||||
|
for (int i = 0; i < funcInputs.size(); ++i) {
|
||||||
|
const auto& funcInput = funcInputs[i];
|
||||||
|
ov::runtime::Tensor tensor;
|
||||||
|
if (i == 1) {
|
||||||
|
tensor = ov::runtime::Tensor{ov::element::i32, targetInputStaticShapes[i]};
|
||||||
|
auto inputData = tensor.data<ov::element_type_traits<ov::element::i32>::value_type>();
|
||||||
|
const std::vector<unsigned> data = {38, 38, 15, 4};
|
||||||
|
for (size_t j = 0lu; j < data.size(); ++j) {
|
||||||
|
inputData[j] = data[j];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (funcInput.get_element_type().is_real()) {
|
||||||
|
tensor = utils::create_and_fill_tensor(funcInput.get_element_type(),
|
||||||
|
targetInputStaticShapes[i],
|
||||||
|
10,
|
||||||
|
0,
|
||||||
|
1000);
|
||||||
|
} else {
|
||||||
|
tensor = utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(InPlaceReshapeFromConstantCheck, smoke_CPU_InPlaceReshapeFromConstantCheck) {
|
||||||
|
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||||
|
|
||||||
|
run();
|
||||||
|
}
|
||||||
|
} // namespace SubgraphTestsDefinitions
|
Loading…
Reference in New Issue
Block a user