[Snippets] Fixed input and output original precisions in Subgraph (#13451)
parent c6528ee4ea
commit db0e2be481
@@ -221,8 +221,8 @@ Shape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShape
                       "Failed to create broadcastable shapes in snippets canonicalization");
         const auto paramShape = m_body->get_parameters()[i]->get_shape();
         const auto paramType = m_body->get_parameters()[i]->get_element_type();
-        if (paramShape.size() != inShape.size() || !equal(paramShape.begin(), paramShape.end(), inShape.begin()) || paramType != inType)
-            m_body->replace_parameter(i, std::make_shared<opset1::Parameter>(inType, inShape));
+        if (paramShape.size() != inShape.size() || !equal(paramShape.begin(), paramShape.end(), inShape.begin()))
+            m_body->replace_parameter(i, std::make_shared<opset1::Parameter>(paramType, inShape));
     }

     m_body->validate_nodes_and_infer_types();
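Note: the fix above keeps the body Parameter's own element type (paramType) and takes only the shape from the input descriptor; previously the descriptor type (inType) overwrote the original precision. A minimal sketch of the fixed behavior, with a hypothetical helper name (remake_parameter is illustrative, not part of the patch):

    // Sketch: rebuild a body input with a new shape while preserving its element type.
    #include <ngraph/opsets/opset1.hpp>

    std::shared_ptr<ngraph::opset1::Parameter> remake_parameter(
            const std::shared_ptr<ngraph::opset1::Parameter>& param,  // existing body Parameter
            const ngraph::Shape& inShape) {                           // canonical input shape
        // paramType is preserved (the fix); inType was used here before
        return std::make_shared<ngraph::opset1::Parameter>(param->get_element_type(), inShape);
    }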
@@ -269,43 +269,37 @@ Shape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& outputShape

 void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outputShapes,
                                                  const BlockedShapeVector& inputShapes) {
     // We should insert Convert before Results to set original output element type if needed
     const auto& body_results = m_body->get_results();
     for (size_t i = 0; i < outputShapes.size(); i++) {
         const auto needed_out_type = std::get<2>(outputShapes[i]);

         // If there is real Convert from graph (ConvertTruncation) or after FQ decomp (ConvertSaturation) before Result
         // we should check destination type and insert ConvertSaturation before that if needed.
         // For example, to return original element type after Convert insertion on inputs
         std::shared_ptr<ov::Node> first_convert = body_results[i];
         while (ov::is_type<ngraph::op::v0::Convert>(first_convert->get_input_node_ptr(0))) {
             first_convert = first_convert->get_input_node_shared_ptr(0);
         }
         if (auto existing_convert_t = ngraph::as_type_ptr<ngraph::op::v0::Convert>(first_convert)) {
             const auto original_input_element_type = existing_convert_t->get_input_element_type(0);
             if (original_input_element_type != execution_element_type) {
                 const auto convert = std::make_shared<ngraph::snippets::op::ConvertSaturation>(
                         existing_convert_t->get_input_node_shared_ptr(0), original_input_element_type);
                 existing_convert_t->set_argument(0, convert);
             }
         }

         // We should insert Convert before Results to return original output element type
-        const auto convert = std::make_shared<ngraph::snippets::op::ConvertSaturation>(
-                body_results[i]->get_input_node_shared_ptr(0), needed_out_type);
-        body_results[i]->set_argument(0, convert);
+        if (body_results[i]->get_input_element_type(0) != needed_out_type) {
+            const auto convert = std::make_shared<ngraph::snippets::op::ConvertSaturation>(
+                    body_results[i]->get_input_node_shared_ptr(0), needed_out_type);
+            body_results[i]->set_argument(0, convert);
+        }
     }

     // We should change existing element type to original for Parameters if needed
     const auto& body_parameters = m_body->get_parameters();
     for (size_t i = 0; i < inputShapes.size(); ++i) {
         const auto needed_in_type = std::get<2>(inputShapes[i]);
         if (body_parameters[i]->get_element_type() != needed_in_type) {
             body_parameters[i]->set_element_type(needed_in_type);
+            config.m_is_needed_to_align_precision = true;
         }
     }

     // We should align element type inside body using the corresponding pass:
     //  - Insert Convert before operations that doesn't support original element type for execution
     //  - Insert reverse Convert before operations that support original element type
     //    but have inputs that doesn't support it (because before them will be inserted Convert with exec_type - first point)
     // Then we should use ConstantFolding pass to convert element type of Scalars before inference.
     // At the end eliminate redundant Convert that could be inserted
     ngraph::pass::Manager manager;
-    manager.register_pass<snippets::pass::AlignElementType>(execution_element_type);
-    manager.register_pass<ngraph::pass::ConstantFolding>();
+    if (config.m_is_needed_to_align_precision) {
+        manager.register_pass<snippets::pass::AlignElementType>(execution_element_type);
+        manager.register_pass<ngraph::pass::ConstantFolding>();
+    }
     manager.register_pass<ngraph::pass::EliminateConvert>();
     manager.run_passes(m_body);
 }
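Note: std::get<2> above reads the original tensor precision out of the snippets BlockedShape descriptor; assuming BlockedShape is std::tuple<ngraph::Shape, ngraph::AxisVector, ngraph::element::Type>, a minimal sketch:

    // The third tuple element is the precision the Subgraph must restore:
    // on Results via ConvertSaturation, on Parameters via set_element_type.
    using BlockedShape = std::tuple<ngraph::Shape, ngraph::AxisVector, ngraph::element::Type>;

    BlockedShape out{ngraph::Shape{1, 42, 16, 64}, ngraph::AxisVector{0, 1, 2, 3}, ngraph::element::bf16};
    const auto needed_out_type = std::get<2>(out);  // bf16: convert back to this before the Result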
@@ -15,22 +15,18 @@

 namespace {

-auto is_in_out_op(const std::shared_ptr<ov::Node>& n) -> bool {
+inline auto is_in_op(const std::shared_ptr<ov::Node>& n) -> bool {
     return ov::is_type<ov::op::v0::Parameter>(n)
-        || ov::is_type<ov::op::v0::Constant>(n)
-        || ov::is_type<ov::op::v0::Result>(n);
+        || ov::is_type<ov::op::v0::Constant>(n);
 }

 // At the moment Subgraph supports only Eltwise, Convert and FQ (which is decomposed into Eltwises and Convert)
 // And only Eltwises supports execution only in "exec_type". So we can check op type from the opposite
-auto op_supports_only_exec_type(const std::shared_ptr<ov::Node>& n) -> bool {
-    return !ov::is_type<ov::op::v0::Convert>(n);
-}
-
-// Check if executable operation supports only execution element type f32
-// NOTE: Executable op is node that isn't Parameter/Constant/Result
-auto is_executable_op_only_on_exec_type(const std::shared_ptr<ov::Node>& n) -> bool {
-    return op_supports_only_exec_type(n) && !is_in_out_op(n);
-}
+// NOTE: This check is only for executable which isn't Parameter/Constant/Result
+inline auto op_supports_only_exec_type(const std::shared_ptr<ov::Node>& n) -> bool {
+    return !is_in_op(n) &&
+           !ov::is_type<ov::op::v0::Result>(n) &&
+           !ov::is_type<ov::op::v0::Convert>(n);
+}

 } // namespace
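Note: a hypothetical usage of the two rewritten predicates (the node names are illustrative):

    // Parameters and Constants are skipped by the pass entirely (is_in_op == true),
    // while Result and Convert are the only ops allowed to keep a non-exec element type.
    auto param   = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, ov::Shape{1, 8});
    auto convert = std::make_shared<ov::op::v0::Convert>(param, ov::element::f32);
    auto add     = std::make_shared<ov::op::v1::Add>(convert, convert);

    // is_in_op(param)                     -> true:  not touched by the pass
    // op_supports_only_exec_type(convert) -> false: Convert may produce any type
    // op_supports_only_exec_type(add)     -> true:  Add must execute in exec_type (f32)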
@@ -50,7 +46,7 @@ bool ngraph::snippets::pass::AlignElementType::run_on_model(const std::shared_pt
     bool rewritten = false;
     auto ops = m->get_ordered_ops();
     for (auto& op : ops) {
-        if (is_in_out_op(op) || ov::is_type<ov::op::v0::Convert>(op)) {
+        if (is_in_op(op)) {
             continue;
         }
@@ -63,7 +59,8 @@ bool ngraph::snippets::pass::AlignElementType::run_on_model(const std::shared_pt
             // - Input is Op which support any element type
             // We couldn't unite these conditions and just check that element type isn't supported exec type
             // because we don't call validate_and_infer_types() so we don't know new precisions
-            if ((existing_convert && existing_convert->get_destination_type() != exec_type) || (!is_executable_op_only_on_exec_type(shared_input))) {
+            if ((existing_convert && existing_convert->get_destination_type() != exec_type) ||
+                (!op_supports_only_exec_type(shared_input))) {
                 insertConvert(op, i, exec_type);
                 rewritten |= true;
             }
@@ -72,7 +69,7 @@ bool ngraph::snippets::pass::AlignElementType::run_on_model(const std::shared_pt
                 tr_node->set_overridden_output_type(exec_type, 0);
                 rewritten |= true;
             }
-        } else {  // branch for the Movement ops and MatMul ops in the future
+        } else {  // branch for Movement ops, MatMul ops in the future and for the Convert, Result
             for (auto i = 0; i < op->inputs().size(); i++) {
                 auto shared_input = op->get_input_node_shared_ptr(i);
                 // it's original element type because we don't use validate_and_infer_type() anywhere
@@ -80,7 +77,7 @@ bool ngraph::snippets::pass::AlignElementType::run_on_model(const std::shared_pt
                 // If before op there is another op that doesn't support execution on original element type, we know that
                 // before this op will be inserted reverse Convert to support execution on supported element type (first branch of condition).
                 // So we should return original element type for operations that can support low precision
-                if (is_executable_op_only_on_exec_type(shared_input) && original_eltype != exec_type) {
+                if (op_supports_only_exec_type(shared_input) && original_eltype != exec_type) {
                     insertConvert(op, i, original_eltype);
                     rewritten |= true;
                 }
@@ -93,5 +90,5 @@ bool ngraph::snippets::pass::AlignElementType::run_on_model(const std::shared_pt

 bool ngraph::snippets::pass::AlignElementType::opNeedsAlignElementType(const std::shared_ptr<ov::Node>& op, const ov::element::Type exec_type) {
     // At the moment Snippets support only Eltwise/Convert/FQ which one output so we can just call get_element_type()
-    return is_executable_op_only_on_exec_type(op) && op->get_element_type() != exec_type;
+    return op_supports_only_exec_type(op) && op->get_element_type() != exec_type;
 }
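Note: a small worked example of this check, under the same single-output assumption the comment states; exec_type is f32 in the snippets pipeline:

    auto p   = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, ov::Shape{1, 8});
    auto add = std::make_shared<ov::op::v1::Add>(p, p);  // output element type is bf16

    // opNeedsAlignElementType(add, ov::element::f32) -> true: eltwise op, bf16 != f32
    // a Convert node would return false here, since op_supports_only_exec_type excludes it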
@@ -13,30 +13,52 @@ namespace snippets {
 namespace {

 INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Eltwise, Add,
-                         ::testing::Combine(
-                                 ::testing::Values(ov::Shape {1, 42, 16, 64}),
-                                 ::testing::Values(ov::Shape {1, 42, 16, 1}),
-                                 ::testing::Values(1), // one node - Add
-                                 ::testing::Values(0), // SnippetsMarkSkipped disables tokenization for eltwise chains after inputs
-                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
-                         Add::getTestCaseName);
+                         ::testing::Combine(
+                                 ::testing::Values(ov::Shape {1, 42, 16, 64}),
+                                 ::testing::Values(ov::Shape {1, 42, 16, 1}),
+                                 ::testing::Values(ov::element::f32),
+                                 ::testing::Values(1), // one node - Add
+                                 ::testing::Values(0), // SnippetsMarkSkipped disables tokenization for eltwise chains after inputs
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         Add::getTestCaseName);

 INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Eltwise, AddSinh,
                          ::testing::Combine(
                                  ::testing::Values(ov::Shape {1, 42, 16, 64}),
                                  ::testing::Values(ov::Shape {1, 42, 16, 1}),
-                                 ::testing::Values(3), // Add + 2 converts after inputs
+                                 ::testing::Values(ov::element::f32),
+                                 ::testing::Values(3), // Add + 2 sinh after inputs
                                  ::testing::Values(1), // Subgraph is created, since the inputs are followed by converts
                                  ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                          AddSinh::getTestCaseName);

 INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Eltwise, AddSinhConst,
-                         ::testing::Combine(
-                                 ::testing::Values(ov::Shape {1, 42, 16, 64}),
-                                 ::testing::Values(2), // Add + 2 converts after inputs
-                                 ::testing::Values(1), // Subgraph is created, since the inputs are followed by converts
-                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
-                         AddSinhConst::getTestCaseName);
+                         ::testing::Combine(
+                                 ::testing::Values(ov::Shape {1, 42, 16, 64}),
+                                 ::testing::Values(ov::element::f32),
+                                 ::testing::Values(2), // Add + sinh after inputs
+                                 ::testing::Values(1), // Subgraph is created, since the inputs are followed by converts
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         AddSinhConst::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Eltwise, AddRollConst,
+                         ::testing::Combine(
+                                 ::testing::Values(ov::Shape {1, 42, 16, 64}),
+                                 ::testing::Values(ov::element::f32),
+                                 ::testing::Values(2), // Add + roll after inputs
+                                 ::testing::Values(1), // Subgraph is created, since the inputs are followed by converts
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         AddRollConst::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Eltwise_BF16, AddRollConst,
+                         ::testing::Combine(
+                                 ::testing::Values(ov::Shape {1, 42, 16, 64}),
+                                 ::testing::Values(ov::element::bf16),
+                                 ::testing::Values(3), // Add + reorder + roll after inputs
+                                 ::testing::Values(1), // Subgraph is created, since the inputs are followed by converts
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         AddRollConst::getTestCaseName);

 } // namespace
 } // namespace snippets
@@ -13,6 +13,7 @@ namespace snippets {
 typedef std::tuple<
         ov::Shape,          // Input 0 Shape
         ov::Shape,          // Input 1 Shape
+        ov::element::Type,  // Element type
         size_t,             // Expected num nodes
         size_t,             // Expected num subgraphs
         std::string         // Target Device
@@ -20,6 +21,7 @@ typedef std::tuple<

 typedef std::tuple<
         ov::Shape,          // Input 0 Shape
+        ov::element::Type,  // Element type
         size_t,             // Expected num nodes
         size_t,             // Expected num subgraphs
         std::string         // Target Device
@@ -47,6 +49,11 @@ protected:
     void SetUp() override;
 };

+class AddRollConst : public AddSinhConst {
+protected:
+    void SetUp() override;
+};
+
 } // namespace snippets
 } // namespace test
 } // namespace ov
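Note: the new ov::element::Type field sits between the shapes and the node counts, so every std::tie in the test sources must unpack it at the same position. The resulting tuple (name as used by the tests) looks like:

    using AddParams = std::tuple<
            ov::Shape,          // Input 0 Shape
            ov::Shape,          // Input 1 Shape
            ov::element::Type,  // Element type
            size_t,             // Expected num nodes
            size_t,             // Expected num subgraphs
            std::string>;       // Target Device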
@@ -12,13 +12,15 @@ namespace snippets {

 std::string Add::getTestCaseName(testing::TestParamInfo<ov::test::snippets::AddParams> obj) {
     ov::Shape inputShapes0, inputShapes1, newInputShapes;
+    ov::element::Type type;
     std::string targetDevice;
     size_t num_nodes, num_subgraphs;
-    std::tie(inputShapes0, inputShapes1, num_nodes, num_subgraphs, targetDevice) = obj.param;
+    std::tie(inputShapes0, inputShapes1, type, num_nodes, num_subgraphs, targetDevice) = obj.param;

     std::ostringstream result;
     result << "IS[0]=" << CommonTestUtils::vec2str(inputShapes0) << "_";
     result << "IS[1]=" << CommonTestUtils::vec2str(inputShapes1) << "_";
+    result << "T=" << type << "_";
     result << "#N=" << num_nodes << "_";
     result << "#S=" << num_subgraphs << "_";
     result << "targetDevice=" << targetDevice;
@@ -27,30 +29,36 @@ std::string Add::getTestCaseName(testing::TestParamInfo<ov::test::snippets::AddP

 void Add::SetUp() {
     ov::Shape inputShape0, inputShape1;
-    std::tie(inputShape0, inputShape1, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
+    ov::element::Type type;
+    std::tie(inputShape0, inputShape1, type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
     init_input_shapes({{{}, {inputShape0, }}, {{}, {inputShape1, }}});

     auto f = ov::test::snippets::AddFunction({inputShape0, inputShape1});
     function = f.getOriginal();
+    setInferenceType(type);
 }

 void AddSinh::SetUp() {
     ov::Shape inputShape0, inputShape1;
-    std::tie(inputShape0, inputShape1, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
+    ov::element::Type type;
+    std::tie(inputShape0, inputShape1, type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
     init_input_shapes({{{}, {inputShape0, }}, {{}, {inputShape1, }}});

     auto f = ov::test::snippets::AddSinhFunction({inputShape0, inputShape1});
     function = f.getOriginal();
+    setInferenceType(type);
 }

 std::string AddSinhConst::getTestCaseName(testing::TestParamInfo<ov::test::snippets::AddConstParams> obj) {
     ov::Shape inputShapes, newInputShapes;
+    ov::element::Type type;
     std::string targetDevice;
     size_t num_nodes, num_subgraphs;
-    std::tie(inputShapes, num_nodes, num_subgraphs, targetDevice) = obj.param;
+    std::tie(inputShapes, type, num_nodes, num_subgraphs, targetDevice) = obj.param;

     std::ostringstream result;
     result << "IS[0]=" << CommonTestUtils::vec2str(inputShapes) << "_";
+    result << "T=" << type << "_";
     result << "#N=" << num_nodes << "_";
     result << "#S=" << num_subgraphs << "_";
     result << "targetDevice=" << targetDevice;
@@ -59,11 +67,24 @@ std::string AddSinhConst::getTestCaseName(testing::TestParamInfo<ov::test::snipp

 void AddSinhConst::SetUp() {
     ov::Shape inputShape;
-    std::tie(inputShape, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
+    ov::element::Type type;
+    std::tie(inputShape, type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
     init_input_shapes({{{}, {inputShape, }}});

     auto f = ov::test::snippets::AddSinhConstFunction({inputShape});
     function = f.getOriginal();
+    setInferenceType(type);
 }

+void AddRollConst::SetUp() {
+    ov::Shape inputShape;
+    ov::element::Type type;
+    std::tie(inputShape, type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
+    init_input_shapes({{{}, {inputShape, }}});
+
+    auto f = ov::test::snippets::AddRollConstFunction({inputShape});
+    function = f.getOriginal();
+    setInferenceType(type);
+}
+
 TEST_P(Add, CompareWithRefImpl) {
@@ -81,6 +102,12 @@ TEST_P(AddSinhConst, CompareWithRefImpl) {
     validateNumSubgraphs();
 }

+TEST_P(AddRollConst, CompareWithRefImpl) {
+    run();
+    validateNumSubgraphs();
+}
+
 } // namespace snippets
 } // namespace test
 } // namespace ov
@@ -15,6 +15,8 @@ protected:

     void validateOriginalLayersNamesByType(const std::string& layerType, const std::string& originalLayersNames);

+    void setInferenceType(ov::element::Type type);
+
     // Expected num nodes and subgraphs in exec graphs depends on the plugin
     // pipeline, tokenization callback for example. Therefore, they have to be provided manually.
     size_t ref_num_nodes = 0;
@@ -53,6 +53,9 @@ void SnippetsTestsCommon::validateOriginalLayersNamesByType(const std::string& l

     ASSERT_TRUE(false) << "Layer type '" << layerType << "' was not found in compiled model";
 }
+void SnippetsTestsCommon::setInferenceType(ov::element::Type type) {
+    configuration.emplace(ov::hint::inference_precision(type));
+}

 } // namespace test
 } // namespace ov
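Note: setInferenceType stores the requested precision in the test's plugin configuration, so it reaches the device as the standard inference_precision hint at compile time. Roughly equivalent direct usage (a sketch, assuming an ov::Model named model is at hand):

    ov::Core core;
    auto compiled = core.compile_model(model, "CPU",
                                       ov::hint::inference_precision(ov::element::bf16));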
@@ -61,6 +61,23 @@ protected:
     std::shared_ptr<ov::Model> initOriginal() const override;
     // std::shared_ptr<ov::Model> initReference() const override;
 };
+// Function is to check for different model precision
+/// Like AddSinhConst but with a Roll instead of Sinh because Roll is movement operation which
+//  supports different precisions but Sinh supports only FP32 in CPU Plugin
+//    in1
+//   Roll   Const
+//       Add
+//      Result
+// The function is needed to check different input element types (model precision change)
+class AddRollConstFunction : public SnippetsFunctionBase {
+public:
+    explicit AddRollConstFunction(const std::vector<Shape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
+        NGRAPH_CHECK(input_shapes.size() == 1, "Got invalid number of input shapes");
+    }
+protected:
+    std::shared_ptr<ov::Model> initOriginal() const override;
+    // std::shared_ptr<ov::Model> initReference() const override;
+};
 /// Simple Eltwise graph fully convertible to Subgraph.
 /// Tokenized simply by attaching eltwises.
 // in1 in2
@@ -54,6 +54,19 @@ std::shared_ptr<ov::Model> AddSinhConstFunction::initOriginal() const {
     auto add = std::make_shared<op::v1::Add>(sin0, const_data1);
     return std::make_shared<ov::Model>(NodeVector{add}, ParameterVector{data0});
 }
+std::shared_ptr<ov::Model> AddRollConstFunction::initOriginal() const {
+    auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
+    const std::vector<float> const_values = CommonTestUtils::generate_float_numbers(shape_size(input_shapes[0]), -10., 10.);
+    auto const_data1 = std::make_shared<op::v0::Constant>(precision, input_shapes[0], const_values);
+    auto shift = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<float>{1});
+    auto axes = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<float>{0});
+    auto roll0 = std::make_shared<ov::op::v7::Roll>(data0, shift, axes);
+    auto add = std::make_shared<op::v1::Add>(roll0, const_data1);
+    // The limitation for BF16 in CPU Plugin:
+    roll0->get_rt_info()["enforceBF16evenForGraphTail"] = true;
+    add->get_rt_info()["enforceBF16evenForGraphTail"] = true;
+    return std::make_shared<ov::Model>(NodeVector{add}, ParameterVector{data0});
+}
 std::shared_ptr<ov::Model> EltwiseFunction::initOriginal() const {
     auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
     auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
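Note: the Roll usage above follows ov::op::v7::Roll(data, shift, axes), which shifts elements along the given axes. A standalone sketch of the same pattern, using the Constant::create shorthand:

    auto data  = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 42, 16, 64});
    auto shift = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1});
    auto axes  = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {0});
    auto roll  = std::make_shared<ov::op::v7::Roll>(data, shift, axes);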