[Snippets] Fixed copy runtime info which contains PortDescriptors (#17774)

This commit is contained in:
Alexandra Sidorova
2023-06-08 16:22:48 +04:00
committed by GitHub
parent b1b1014a34
commit 385cfee24a
11 changed files with 93 additions and 17 deletions

View File

@@ -82,6 +82,9 @@ private:
static void init_default(std::vector<PortDescriptorPtr>& in_descs, std::vector<PortDescriptorPtr>& out_descs, const std::shared_ptr<ov::Node>& node);
};
// PortDescriptorVectorAttribute is a non-copyable attribute!
// This is needed to avoid incorrect copying of rt info between different nodes in calls to copy_runtime_info() (for example, in transformations)
// The attribute must be copied manually if needed
// Runtime attribute holding per-port descriptors for one node.
// Deliberately non-copyable so that generic rt-info propagation never transfers
// port descriptors onto unrelated nodes.
class PortDescriptorVectorAttribute : public ov::RuntimeAttribute {
public:
OPENVINO_RTTI("PortDescriptorVectorAttribute", "", ov::RuntimeAttribute);
@@ -90,6 +93,8 @@ public:
// Takes ownership of the given descriptor vectors (moved in, not copied).
explicit PortDescriptorVectorAttribute(std::vector<PortDescriptorPtr> in_descs = {}, std::vector<PortDescriptorPtr> out_descs = {})
: inputs(std::move(in_descs)), outputs(std::move(out_descs)) {}
// Returning false blocks copy_runtime_info() from cloning this attribute to other nodes.
bool is_copyable() const override { return false; }
// Descriptors for the owning node's input ports and output ports, respectively.
std::vector<PortDescriptorPtr> inputs{};
std::vector<PortDescriptorPtr> outputs{};
};

View File

@@ -28,9 +28,6 @@ ov::PartialShape get_port_planar_shape(const Input<Node>& out);
ov::PartialShape get_port_planar_shape(const Output<Node>& out);
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout);
// Copy runtime info using default ngraph method but delete PortDescriptors which may be transferred after copying
void safe_copy_runtime_info(const std::shared_ptr<ov::Node>&, const std::shared_ptr<ov::Node>& to);
// Normalizes a possibly-negative allocation rank against a shape rank:
// non-negative ranks pass through unchanged, while a negative rank counts
// back from the end (negative-index style over shape_rank + 1 positions).
inline auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t {
    if (allocation_rank >= 0)
        return allocation_rank;
    return static_cast<int32_t>(shape_rank) + 1 + allocation_rank;
}

View File

@@ -4,8 +4,8 @@
#include "snippets/itt.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/utils.hpp"
#include "snippets/pass/convert_power_to_powerstatic.hpp"
#include "openvino/core/rt_info.hpp"
ov::snippets::pass::ConvertPowerToPowerStatic::ConvertPowerToPowerStatic() {
@@ -22,7 +22,7 @@ ov::snippets::pass::ConvertPowerToPowerStatic::ConvertPowerToPowerStatic() {
auto value = scalar->cast_vector<float>()[0];
auto power_static = std::make_shared<snippets::op::PowerStatic>(power->input(0).get_source_output(), value);
power_static->set_friendly_name(power->get_friendly_name());
utils::safe_copy_runtime_info(power, power_static);
copy_runtime_info(power, power_static);
ov::replace_node(power, power_static);
return true;

View File

@@ -10,6 +10,7 @@
#include "openvino/opsets/opset1.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/core/rt_info.hpp"
#include <numeric>
@@ -45,7 +46,7 @@ ov::Output<ov::Node> ov::snippets::pass::InsertMoveBroadcast::BroadcastNodeLastD
ov::PartialShape broadcasted_shape = normalized_shape;
*broadcasted_shape.rbegin() = *target_shape.rbegin();
const auto broadcast_node = std::make_shared<ov::snippets::op::BroadcastMove>(value, broadcasted_shape);
utils::safe_copy_runtime_info(value.get_node_shared_ptr(), broadcast_node);
copy_runtime_info(value.get_node_shared_ptr(), broadcast_node);
return broadcast_node->output(0);
}

View File

@@ -7,6 +7,7 @@
#include "ov_ops/type_relaxed.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils.hpp"
#include "openvino/core/rt_info.hpp"
#include <assert.h>
#include <memory>
@@ -130,7 +131,7 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
parent_output,
required_after);
utils::safe_copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
op->set_argument(op_input.get_index(), convert);
continue;
}
@@ -149,7 +150,7 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
existing_convert->get_input_node_shared_ptr(0),
required_after);
utils::safe_copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
op->set_argument(op_input.get_index(), convert);
continue;
}
@@ -158,7 +159,7 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
existing_convert->output(0),
required_after);
utils::safe_copy_runtime_info(existing_convert->output(0).get_node()->shared_from_this(), convert);
copy_runtime_info(existing_convert->output(0).get_node()->shared_from_this(), convert);
op->set_argument(op_input.get_index(), convert);
}
}
@@ -180,7 +181,7 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
result->get_input_node_shared_ptr(0),
expected_type);
utils::safe_copy_runtime_info(result->get_input_node_shared_ptr(0), convert);
copy_runtime_info(result->get_input_node_shared_ptr(0), convert);
result->set_argument(0, convert);
}
}
@@ -223,7 +224,7 @@ bool ov::snippets::pass::PropagatePrecision::validate_and_infer_types_and_restor
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
output,
op_output_types[i]);
utils::safe_copy_runtime_info(output.get_node_shared_ptr(), convert);
copy_runtime_info(output.get_node_shared_ptr(), convert);
for (auto& input : output.get_target_inputs()) {
auto child = input.get_node();

View File

@@ -97,11 +97,6 @@ ov::PartialShape get_port_planar_shape(const Output<Node>& out) {
return utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout());
}
// Copies runtime info from `from` to `to` via the default ov::copy_runtime_info(),
// then strips PortDescriptor attributes from `to`, since port descriptors are
// node-specific and must not be shared between different nodes.
void safe_copy_runtime_info(const std::shared_ptr<ov::Node>& from, const std::shared_ptr<ov::Node>& to) {
ov::copy_runtime_info(from, to);
lowered::PortDescriptorUtils::clean(to);
}
} // namespace utils
} // namespace snippets
} // namespace ov

View File

@@ -64,6 +64,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHA,
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
// Instantiates the MHA-with-Mul/Add smoke suite on CPU: one static 4D input shape
// triple, f32 precisions, empty plugin config. The two Values(1) are expected counts —
// presumably expected nodes and expected subgraphs; confirm against the MHA test params.
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAMulAdd, MHAMulAdd,
::testing::Combine(
::testing::Values(std::vector<ov::PartialShape>{{1, 10, 12, 16}, {1, 10, 12, 16}, {1, 10, 12, 16}}),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false}), // Need to support True for graph builder in tests
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
// without broadcast
{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 12, 128, 128}, {1, 12, 128, 128}, {1, 128, 12, 64}},

View File

@@ -67,6 +67,11 @@ protected:
void init_subgraph() override;
};
// MHA test variant whose subgraph contains explicit Multiply and Add ops
// (built from MHAMulAddFunction in init_subgraph()).
class MHAMulAdd : public MHA {
void init_subgraph() override;
};
} // namespace snippets
} // namespace test
} // namespace ov

View File

@@ -129,6 +129,11 @@ void MHAFQ::init_subgraph() {
function = f.getOriginal();
}
// Builds the original (non-tokenized) MHA+Mul+Add model for the test's input shapes.
void MHAMulAdd::init_subgraph() {
    auto builder = ov::test::snippets::MHAMulAddFunction(inputDynamicShapes);
    function = builder.getOriginal();
}
TEST_P(MHA, CompareWithRefImpl) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
@@ -152,6 +157,11 @@ TEST_P(MHAWOTranspose, CompareWithRefImpl) {
validateNumSubgraphs();
}
// Runs the MHAMulAdd subgraph against the reference implementation.
TEST_P(MHAMulAdd, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()  // consistency: the sibling TEST_P(MHA, ...) honors the disabled-tests filter
    run();
    validateNumSubgraphs();
}
TEST_P(MHAINT8MatMul, CompareWithRefImpl) {
run();
validateNumSubgraphs();

View File

@@ -268,6 +268,27 @@ protected:
std::shared_ptr<ov::Model> initReference() const override;
};
/* Graph:
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
* \ /
* MatMul0
* \
* Multiply
* Add
* Softmax Transpose2[0,2,1,3]
* \ /
* MatMul1
* Transpose3[0,2,1,3]
*/
// Builds the MHA pattern with explicit Multiply and Add between MatMul0 and Softmax
// (see the graph sketch above). Requires exactly 3 input shapes — presumably the
// Q/K/V transpose inputs; confirm against initOriginal().
class MHAMulAddFunction : public SnippetsFunctionBase {
public:
explicit MHAMulAddFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};
} // namespace snippets
} // namespace test
} // namespace ov

View File

@@ -361,7 +361,6 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAFQAfterMatMulFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
@@ -657,6 +656,35 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() con
return std::make_shared<ov::Model>(NodeVector{transpose3}, ngraphParams);
}
// Builds the reference MHA graph with explicit Multiply and Add between MatMul0
// and Softmax: Transpose0/1 -> MatMul0 -> Mul -> Add -> Softmax -> MatMul1(Transpose2)
// -> Transpose3. Returns the assembled ov::Model with three parameters.
std::shared_ptr<ov::Model> MHAMulAddFunction::initOriginal() const {
    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, transpose2Param};
    // Permutations follow the class-level graph sketch: [0,2,1,3] / [0,2,3,1] / [0,2,1,3] / [0,2,1,3].
    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{input_shapes[0].size()}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{input_shapes[1].size()}, std::vector<int64_t>{0, 2, 3, 1});
    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{input_shapes[2].size()}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{input_shapes[2].size()}, std::vector<int64_t>{0, 2, 1, 3});
    // Fix: MatMul transpose flags are booleans; the originals were declared `float`.
    bool transA = false;
    bool transB = false;
    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
    // NOTE(review): get_shape() requires a static output shape here — presumably these
    // test input shapes are always static; confirm if dynamic shapes are ever passed.
    // Random-valued constants broadcast-compatible with MatMul0's output.
    auto mulConst = ngraph::builder::makeConstant(ngraph::element::f32, matMul0->get_shape(), std::vector<float>{}, true);
    auto addConst = ngraph::builder::makeConstant(ngraph::element::f32, matMul0->get_shape(), std::vector<float>{}, true);
    const auto mul = std::make_shared<ngraph::opset3::Multiply>(matMul0, mulConst);
    const auto add = std::make_shared<ngraph::opset3::Add>(mul, addConst);
    const auto softMax = std::make_shared<ov::op::v8::Softmax>(add, -1);  // softmax over the last axis
    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
    const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matMul1, transpose3Const);
    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
} // namespace snippets
} // namespace test
} // namespace ov