[Snippets] Fixed copy runtime info which contains PortDescriptors (#17774)
This commit is contained in:
committed by
GitHub
parent
b1b1014a34
commit
385cfee24a
@@ -82,6 +82,9 @@ private:
|
||||
static void init_default(std::vector<PortDescriptorPtr>& in_descs, std::vector<PortDescriptorPtr>& out_descs, const std::shared_ptr<ov::Node>& node);
|
||||
};
|
||||
|
||||
// PortDescriptorVectorAttribute is not copyable attribute!
|
||||
// It's needed to avoid incorrect copies of rt info between different nodes in call copy_runtime_info() (for example, in transformations)
|
||||
// The attribute must be manually copied if needed
|
||||
class PortDescriptorVectorAttribute : public ov::RuntimeAttribute {
|
||||
public:
|
||||
OPENVINO_RTTI("PortDescriptorVectorAttribute", "", ov::RuntimeAttribute);
|
||||
@@ -90,6 +93,8 @@ public:
|
||||
explicit PortDescriptorVectorAttribute(std::vector<PortDescriptorPtr> in_descs = {}, std::vector<PortDescriptorPtr> out_descs = {})
|
||||
: inputs(std::move(in_descs)), outputs(std::move(out_descs)) {}
|
||||
|
||||
bool is_copyable() const override { return false; }
|
||||
|
||||
std::vector<PortDescriptorPtr> inputs{};
|
||||
std::vector<PortDescriptorPtr> outputs{};
|
||||
};
|
||||
|
||||
@@ -28,9 +28,6 @@ ov::PartialShape get_port_planar_shape(const Input<Node>& out);
|
||||
ov::PartialShape get_port_planar_shape(const Output<Node>& out);
|
||||
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout);
|
||||
|
||||
// Copy runtime info using default ngraph method but delete PortDescriptors which may be transferred after copying
|
||||
void safe_copy_runtime_info(const std::shared_ptr<ov::Node>&, const std::shared_ptr<ov::Node>& to);
|
||||
|
||||
inline auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t {
|
||||
return allocation_rank < 0 ? allocation_rank + static_cast<int32_t>(shape_rank) + 1 : allocation_rank;
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/snippets_isa.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
#include "snippets/pass/convert_power_to_powerstatic.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
|
||||
|
||||
ov::snippets::pass::ConvertPowerToPowerStatic::ConvertPowerToPowerStatic() {
|
||||
@@ -22,7 +22,7 @@ ov::snippets::pass::ConvertPowerToPowerStatic::ConvertPowerToPowerStatic() {
|
||||
auto value = scalar->cast_vector<float>()[0];
|
||||
auto power_static = std::make_shared<snippets::op::PowerStatic>(power->input(0).get_source_output(), value);
|
||||
power_static->set_friendly_name(power->get_friendly_name());
|
||||
utils::safe_copy_runtime_info(power, power_static);
|
||||
copy_runtime_info(power, power_static);
|
||||
ov::replace_node(power, power_static);
|
||||
|
||||
return true;
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "openvino/opsets/opset1.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
@@ -45,7 +46,7 @@ ov::Output<ov::Node> ov::snippets::pass::InsertMoveBroadcast::BroadcastNodeLastD
|
||||
ov::PartialShape broadcasted_shape = normalized_shape;
|
||||
*broadcasted_shape.rbegin() = *target_shape.rbegin();
|
||||
const auto broadcast_node = std::make_shared<ov::snippets::op::BroadcastMove>(value, broadcasted_shape);
|
||||
utils::safe_copy_runtime_info(value.get_node_shared_ptr(), broadcast_node);
|
||||
copy_runtime_info(value.get_node_shared_ptr(), broadcast_node);
|
||||
|
||||
return broadcast_node->output(0);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ov_ops/type_relaxed.hpp"
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
#include <memory>
|
||||
@@ -130,7 +131,7 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
|
||||
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
|
||||
parent_output,
|
||||
required_after);
|
||||
utils::safe_copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
|
||||
copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
|
||||
op->set_argument(op_input.get_index(), convert);
|
||||
continue;
|
||||
}
|
||||
@@ -149,7 +150,7 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
|
||||
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
|
||||
existing_convert->get_input_node_shared_ptr(0),
|
||||
required_after);
|
||||
utils::safe_copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
|
||||
copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
|
||||
op->set_argument(op_input.get_index(), convert);
|
||||
continue;
|
||||
}
|
||||
@@ -158,7 +159,7 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
|
||||
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
|
||||
existing_convert->output(0),
|
||||
required_after);
|
||||
utils::safe_copy_runtime_info(existing_convert->output(0).get_node()->shared_from_this(), convert);
|
||||
copy_runtime_info(existing_convert->output(0).get_node()->shared_from_this(), convert);
|
||||
op->set_argument(op_input.get_index(), convert);
|
||||
}
|
||||
}
|
||||
@@ -180,7 +181,7 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
|
||||
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
|
||||
result->get_input_node_shared_ptr(0),
|
||||
expected_type);
|
||||
utils::safe_copy_runtime_info(result->get_input_node_shared_ptr(0), convert);
|
||||
copy_runtime_info(result->get_input_node_shared_ptr(0), convert);
|
||||
result->set_argument(0, convert);
|
||||
}
|
||||
}
|
||||
@@ -223,7 +224,7 @@ bool ov::snippets::pass::PropagatePrecision::validate_and_infer_types_and_restor
|
||||
auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
|
||||
output,
|
||||
op_output_types[i]);
|
||||
utils::safe_copy_runtime_info(output.get_node_shared_ptr(), convert);
|
||||
copy_runtime_info(output.get_node_shared_ptr(), convert);
|
||||
|
||||
for (auto& input : output.get_target_inputs()) {
|
||||
auto child = input.get_node();
|
||||
|
||||
@@ -97,11 +97,6 @@ ov::PartialShape get_port_planar_shape(const Output<Node>& out) {
|
||||
return utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout());
|
||||
}
|
||||
|
||||
void safe_copy_runtime_info(const std::shared_ptr<ov::Node>& from, const std::shared_ptr<ov::Node>& to) {
|
||||
ov::copy_runtime_info(from, to);
|
||||
lowered::PortDescriptorUtils::clean(to);
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
||||
|
||||
@@ -64,6 +64,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHA,
|
||||
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAMulAdd, MHAMulAdd,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<ov::PartialShape>{{1, 10, 12, 16}, {1, 10, 12, 16}, {1, 10, 12, 16}}),
|
||||
::testing::ValuesIn(precision_f32(3)),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::ValuesIn({false}), // Need to support True for graph builder in tests
|
||||
::testing::Values(1),
|
||||
::testing::Values(1),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
|
||||
// without broadcast
|
||||
{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 12, 128, 128}, {1, 12, 128, 128}, {1, 128, 12, 64}},
|
||||
|
||||
@@ -67,6 +67,11 @@ protected:
|
||||
void init_subgraph() override;
|
||||
};
|
||||
|
||||
class MHAMulAdd : public MHA {
|
||||
void init_subgraph() override;
|
||||
};
|
||||
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
||||
@@ -129,6 +129,11 @@ void MHAFQ::init_subgraph() {
|
||||
function = f.getOriginal();
|
||||
}
|
||||
|
||||
void MHAMulAdd::init_subgraph() {
|
||||
auto f = ov::test::snippets::MHAMulAddFunction(inputDynamicShapes);
|
||||
function = f.getOriginal();
|
||||
}
|
||||
|
||||
TEST_P(MHA, CompareWithRefImpl) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
run();
|
||||
@@ -152,6 +157,11 @@ TEST_P(MHAWOTranspose, CompareWithRefImpl) {
|
||||
validateNumSubgraphs();
|
||||
}
|
||||
|
||||
TEST_P(MHAMulAdd, CompareWithRefImpl) {
|
||||
run();
|
||||
validateNumSubgraphs();
|
||||
}
|
||||
|
||||
TEST_P(MHAINT8MatMul, CompareWithRefImpl) {
|
||||
run();
|
||||
validateNumSubgraphs();
|
||||
|
||||
@@ -268,6 +268,27 @@ protected:
|
||||
std::shared_ptr<ov::Model> initReference() const override;
|
||||
};
|
||||
|
||||
/* Graph:
|
||||
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
|
||||
* \ /
|
||||
* MatMul0
|
||||
* \
|
||||
* Multiply
|
||||
* Add
|
||||
* Softmax Transpose2[0,2,1,3]
|
||||
* \ /
|
||||
* MatMul1
|
||||
* Transpose3[0,2,1,3]
|
||||
*/
|
||||
class MHAMulAddFunction : public SnippetsFunctionBase {
|
||||
public:
|
||||
explicit MHAMulAddFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
|
||||
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
|
||||
}
|
||||
protected:
|
||||
std::shared_ptr<ov::Model> initOriginal() const override;
|
||||
};
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
||||
@@ -361,7 +361,6 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
|
||||
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
|
||||
}
|
||||
|
||||
|
||||
std::shared_ptr<ov::Model> MHAFQAfterMatMulFunction::initOriginal() const {
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
@@ -657,6 +656,35 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() con
|
||||
return std::make_shared<ov::Model>(NodeVector{transpose3}, ngraphParams);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> MHAMulAddFunction::initOriginal() const {
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, transpose2Param};
|
||||
|
||||
auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{input_shapes[0].size()}, std::vector<int64_t>{0, 2, 1, 3});
|
||||
auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{input_shapes[1].size()}, std::vector<int64_t>{0, 2, 3, 1});
|
||||
auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{input_shapes[2].size()}, std::vector<int64_t>{0, 2, 1, 3});
|
||||
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{input_shapes[2].size()}, std::vector<int64_t>{0, 2, 1, 3});
|
||||
|
||||
float transA = false;
|
||||
float transB = false;
|
||||
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
|
||||
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
|
||||
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
|
||||
auto mulConst = ngraph::builder::makeConstant(ngraph::element::f32, matMul0->get_shape(), std::vector<float>{}, true);
|
||||
auto addConst = ngraph::builder::makeConstant(ngraph::element::f32, matMul0->get_shape(), std::vector<float>{}, true);
|
||||
const auto mul = std::make_shared<ngraph::opset3::Multiply>(matMul0, mulConst);
|
||||
const auto add = std::make_shared<ngraph::opset3::Add>(mul, addConst);
|
||||
const auto softMax = std::make_shared<ov::op::v8::Softmax>(add, -1);
|
||||
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
|
||||
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
|
||||
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matMul1, transpose3Const);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
|
||||
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
|
||||
}
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
||||
Reference in New Issue
Block a user