[CPU] Add input type check into in-place condition (#20529)

This commit is contained in:
Aleksandr Voron 2023-10-24 16:45:48 +02:00 committed by GitHub
parent 63fff9d270
commit c6707aab86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 1 deletions

View File

@ -450,7 +450,9 @@ void Edge::init() {
DEBUG_LOG(*this, " getBaseEdge() return itself"); DEBUG_LOG(*this, " getBaseEdge() return itself");
changeStatus(Status::NeedAllocation); changeStatus(Status::NeedAllocation);
} else { } else {
if (edgePtr->getParent()->isConstant() && !edgePtr->getChild()->isConstant()) { if (Type::Input == edgePtr->getParent()->getType() &&
edgePtr->getParent()->isConstant() &&
!edgePtr->getChild()->isConstant()) {
changeStatus(Status::NeedAllocation); changeStatus(Status::NeedAllocation);
DEBUG_LOG(*this, " edge inplace from ", *edgePtr, " is broken!"); DEBUG_LOG(*this, " edge inplace from ", *edgePtr, " is broken!");
return; return;

View File

@ -0,0 +1,77 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <shared_test_classes/base/ov_subgraph.hpp>
#include "test_utils/cpu_test_utils.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "ov_models/utils/ov_helpers.hpp"
#include "ov_models/builders.hpp"
using namespace CPUTestUtils;
using namespace InferenceEngine;
namespace SubgraphTestsDefinitions {
// If a node (CumSum) with constant parents has several non-constant nodes (Eltwises) than the edge is broken.
// The fix is to check node type - is should be Input.
// Subgraph:
/*
* Constant Constant
* \ /
* \ /
* CumSum
* Parameter / \ Parameter
* \ / \ /
* \ / \ /
* Eltwise Eltwise
* \ /
* Eltwise
* |
* Result
*/
using namespace ov::test;
class NonInputInPlaceTest : public testing::WithParamInterface<ElementType>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(testing::TestParamInfo<ElementType> obj) {
std::ostringstream result;
result << "NonInputInPlaceTest_inPrc=outPrc=" << obj.param;
return result.str();
}
void SetUp() override {
targetDevice = utils::DEVICE_CPU;
configuration.insert({ov::hint::inference_precision.name(), ov::element::f16.to_string()});
const std::vector<size_t> inputShape = {1, 11, 3, 3};
targetStaticShapes = {{inputShape, inputShape}};
ElementType prc = this->GetParam();
ov::ParameterVector inputParams {std::make_shared<ov::op::v0::Parameter>(prc, ov::Shape(inputShape)),
std::make_shared<ov::op::v0::Parameter>(prc, ov::Shape(inputShape))};
auto cumsum_tensor = ngraph::opset8::Constant::create(prc, inputShape, {10.0f});
auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i32, {}, {0});
const auto cumsum = std::make_shared<ov::op::v0::CumSum>(cumsum_tensor, axis_node);
auto eltwiseMul = ngraph::builder::makeEltwise(inputParams[0], cumsum, ngraph::helpers::EltwiseTypes::MULTIPLY);
auto eltwiseAdd1 = ngraph::builder::makeEltwise(inputParams[1], cumsum, ngraph::helpers::EltwiseTypes::ADD);
auto eltwiseAdd2 = ngraph::builder::makeEltwise(eltwiseAdd1, eltwiseMul, ngraph::helpers::EltwiseTypes::ADD);
ngraph::ResultVector results{std::make_shared<ngraph::opset8::Result>(eltwiseAdd2)};
function = std::make_shared<ngraph::Function>(results, inputParams, "NonInputInPlaceT");
}
};
namespace {
TEST_P(NonInputInPlaceTest, CompareWithRefs) {
run();
}
INSTANTIATE_TEST_SUITE_P(smoke_NonInputInPlaceTest_CPU, NonInputInPlaceTest,
testing::Values(ngraph::element::f32, ngraph::element::f16),
NonInputInPlaceTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions