[HETERO][nGraph] Fix ConstantFolding fused_names handling (#13766)
* [HETERO][CPU][GPU][TEMPLATE][nGraph] Fix ConstantFolding fused_names propagation. Add tests for QN. Fix unsupported consumers in QN. Fix ConvolutionMultiplyFusion for CPU. Enable fused_names check for Constants. Fix memory formats rt_info for CPU.
* Update src/plugins/intel_cpu/src/utils/rt_info/memory_formats_attribute.hpp (Co-authored-by: Ilya Churaev <ilyachur@gmail.com>)
* Update src/tests/unit/inference_engine/query_model_test.cpp (Co-authored-by: Ilya Churaev <ilyachur@gmail.com>)
* Update src/plugins/intel_cpu/src/utils/rt_info/memory_formats_attribute.hpp (Co-authored-by: Ilya Churaev <ilyachur@gmail.com>)
* Update src/plugins/intel_cpu/src/ngraph_transformations/convert_matmul_to_fc.cpp (Co-authored-by: Ilya Churaev <ilyachur@gmail.com>)
* Update src/core/src/rt_info.cpp (Co-authored-by: Ilya Churaev <ilyachur@gmail.com>)
* Update src/core/src/rt_info.cpp (Co-authored-by: Ilya Churaev <ilyachur@gmail.com>)
* Fix review comments
* Fix test
* Code style
* Review comments
* Don't add Parameters/Inputs/Results to supported due to supported consumers/sources. Add tests for that.
* Fix rt_info propagation for ConstantFolding
* Fix test build
* Code style
* Review comments. Add ShapeOfBase.

Co-authored-by: Ilya Churaev <ilyachur@gmail.com>
Parent: e09995bc60
Commit: 539f17df62
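For context, a minimal sketch of the pattern this commit establishes across passes (see ShapeOfConstFolding below), assuming the OpenVINO 2022.x graph API used throughout this diff; the header paths and the fold_to_constant helper are assumptions, not code from the commit:

    // Sketch: replacing a node with a pre-computed Constant while keeping
    // copyable runtime attributes (e.g. fused_names) on the replacement.
    #include "openvino/core/graph_util.hpp"  // ov::replace_node (assumed header)
    #include "openvino/core/rt_info.hpp"     // ov::copy_runtime_info

    void fold_to_constant(const std::shared_ptr<ov::Node>& node,
                          const std::shared_ptr<ov::Node>& constant) {
        constant->set_friendly_name(node->get_friendly_name());
        // After this commit, copy_runtime_info also seeds attribute-free
        // Constant inputs of the destination and honors is_copyable(to).
        ov::copy_runtime_info(node, constant);
        ov::replace_node(node, constant);
    }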
@@ -67,7 +67,9 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
        if (fq_users.size() == 1 && has_dequantization_subgraph(fq_users[0])) {
            auto& first_convert = fq_users[0];
            if (auto new_weights = ov::get_constant_from_source(first_convert)) {
                new_weights->set_friendly_name(first_convert->get_friendly_name());
                replace_node(first_convert, new_weights);
                copy_runtime_info(first_convert, new_weights);
                // preserve dequantization subgraph for LP transformations
                auto weights_users = new_weights->get_users();
                if (weights_users.size() == 1 && ov::is_type<ngraph::opset8::Convert>(weights_users[0])) {
@@ -118,6 +118,7 @@ static bool handle_variadic_split(const std::shared_ptr<ov::Node>& split) {
    const auto& split_lengths_type = split_lengths_node->get_output_element_type(0);
    const auto sub_const = ngraph::opset6::Constant::create(split_lengths_type, {sub_values.size()}, sub_values);
    const auto sub = std::make_shared<ngraph::opset6::Subtract>(split->input_value(2), sub_const);
    copy_runtime_info(split->get_input_source_output(2).get_node_shared_ptr(), {sub_const, sub});
    split->input(2).replace_source_output(sub);

    return true;
@@ -259,6 +260,7 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
            }
            auto new_const = opset6::Constant::create(const_node->get_element_type(), Shape{res.size()}, res);
            replace_node(const_node, new_const);
            copy_runtime_info(const_node, new_const);
            NGRAPH_DEBUG << "Transform shape like (" << last_output.get_node()->get_friendly_name()
                         << "): " << const_node->get_shape_val() << " to " << new_const->get_shape_val() << std::endl;
            new_const->set_friendly_name(const_node->get_friendly_name());
@@ -303,6 +305,7 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
            for (auto consumer : consumers) {
                consumer.replace_source_output(last_output);
            }
            copy_runtime_info(const_node, last_output.get_node_shared_ptr());
        }
    }
    NGRAPH_DEBUG << "[ INFO ] TOTAL WEIGHTS: " << total_weights_count << std::endl;
@@ -5,6 +5,7 @@
#include "transformations/common_optimizations/align_eltwise_input_ranks.hpp"

#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset8.hpp>

ov::pass::AlignEltwiseInputRanks::AlignEltwiseInputRanks() {
@@ -50,6 +51,7 @@ ov::pass::AlignEltwiseInputRanks::AlignEltwiseInputRanks() {
            Shape new_shape = const_shape;
            new_shape.insert(new_shape.begin(), diff, 1);
            auto new_const = std::make_shared<opset8::Constant>(*const_node, new_shape);
            copy_runtime_info(node->get_input_node_shared_ptr(i), new_const);
            node->input(i).replace_source_output(new_const);
        }
    }
@@ -55,6 +55,7 @@ ov::pass::FoldSubgraphEmptyInputs::FoldSubgraphEmptyInputs() {
                         std::end(multi_subgraph_op_inputs),
                         input,
                         const_empty_replacement);
            copy_runtime_info(input.get_node_shared_ptr(), const_empty_replacement.get_node_shared_ptr());
        }
        multi_subgraph_op->set_arguments(multi_subgraph_op_inputs);
        return true;
@@ -53,6 +53,7 @@ ov::pass::RemoveConcatZeroDimInput::RemoveConcatZeroDimInput() {
            const auto& empty_constant = opset8::Constant::create(concat->get_output_element_type(0),
                                                                  concat->get_output_partial_shape(0).to_shape(),
                                                                  {});
            copy_runtime_info(concat, empty_constant);
            concat->output(0).replace(empty_constant);
            empty_constant->set_friendly_name(concat->get_friendly_name());
        } else {
@@ -155,6 +155,10 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
                                                               Shape{},
                                                               reversed_transposed_split_axis);
        split->input(1).replace_source_output(new_split_axis_const);
        copy_runtime_info({split_axis_constant,
                           output_transpose.transpose->shared_from_this(),
                           output_transpose.transpose_const->shared_from_this()},
                          new_split_axis_const);

        // remove split output transposes
        for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
@@ -196,6 +200,8 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
        auto new_split_axis_const =
            std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
        split_node->input(1).replace_source_output(new_split_axis_const);
        copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
                          new_split_axis_const);

        return true;
    };
@@ -4,6 +4,8 @@

#include "transformations/smart_reshape/shape_of_const_folding.hpp"

#include <openvino/core/rt_info.hpp>

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/shape_of.hpp"
@@ -19,6 +21,7 @@ ov::pass::ShapeOfConstFolding::ShapeOfConstFolding() {
        auto node = m.get_match_root();
        if (auto constant = get_constant_from_source(node)) {
            constant->set_friendly_name(node->get_friendly_name());
            copy_runtime_info(node, constant);
            replace_node(node, constant);
            return true;
        }
@@ -30,6 +30,7 @@ public:
    using Base = std::tuple<::ov::RuntimeAttribute>;
    virtual ~RuntimeAttribute() = default;
    virtual bool is_copyable() const;
    virtual bool is_copyable(const std::shared_ptr<Node>& to) const;
    virtual Any init(const std::shared_ptr<Node>& node) const;
    virtual Any merge(const ov::NodeVector& nodes) const;
    virtual Any merge(const ov::OutputVector& outputs) const;
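The new virtual overload lets an attribute veto copying per destination node; the default implementation (added in src/core/src/runtime_attribute.cpp below) simply falls back to is_copyable(). A sketch with a hypothetical attribute (ExampleAttr is not from the commit):

    #include "openvino/core/runtime_attribute.hpp"
    #include "openvino/op/util/op_types.hpp"

    // Hypothetical attribute; illustrates the destination-aware overload only.
    class ExampleAttr : public ov::RuntimeAttribute {
    public:
        OPENVINO_RTTI("example_attr", "0");
        bool is_copyable(const std::shared_ptr<ov::Node>& to) const override {
            // Refuse to be copied onto Constants, mirroring
            // intel_cpu::MemoryFormats later in this diff.
            return !ov::op::util::is_constant(to);
        }
    };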
@@ -4,16 +4,16 @@

#pragma once

-#include "openvino/op/op.hpp"
+#include "openvino/op/util/shape_of_base.hpp"

namespace ov {
namespace op {
namespace v3 {
/// \brief Operation that returns the shape of its input argument as a tensor.
/// \ingroup ov_ops_cpp_api
-class OPENVINO_API ShapeOf : public Op {
+class OPENVINO_API ShapeOf : public util::ShapeOfBase {
public:
-    OPENVINO_OP("ShapeOf", "opset3", op::Op, 3);
+    OPENVINO_OP("ShapeOf", "opset3", util::ShapeOfBase, 3);
    ShapeOf() = default;
    /// \brief Constructs a shape-of operation.
    ShapeOf(const Output<Node>& arg, const element::Type output_type = element::i64);
@@ -51,9 +51,9 @@ private:
namespace v0 {
/// \brief Operation that returns the shape of its input argument as a tensor.
/// \ingroup ov_ops_cpp_api
-class OPENVINO_API ShapeOf : public Op {
+class OPENVINO_API ShapeOf : public util::ShapeOfBase {
public:
-    OPENVINO_OP("ShapeOf", "opset1");
+    OPENVINO_OP("ShapeOf", "opset1", util::ShapeOfBase);
    ShapeOf() = default;
    /// \brief Constructs a shape-of operation.
    ShapeOf(const Output<Node>& arg);
new file: src/core/include/openvino/op/util/shape_of_base.hpp (23 lines)
@@ -0,0 +1,23 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace op {
namespace util {
class OPENVINO_API ShapeOfBase : public Op {
public:
    OPENVINO_OP("ShapeOfBase", "util");

    ShapeOfBase() = default;

    /// \brief Constructs a ShapeOfBase operation.
    explicit ShapeOfBase(const OutputVector& arguments) : Op(arguments) {}
};
}  // namespace util
}  // namespace op
}  // namespace ov
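The common base allows a single type check to match every ShapeOf version, which is exactly how constant_folding.cpp uses it later in this diff; a short sketch:

    // Before: is_type<ov::opset1::ShapeOf>(node) || is_type<ov::opset3::ShapeOf>(node)
    // After: one check covers v0, v3 and any future ShapeOf:
    if (ov::is_type<ov::op::util::ShapeOfBase>(node)) {
        // treat `node` as the border of a shape-of sub-graph
    }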
@@ -22,7 +22,7 @@ public:
    bool run_on_model(const std::shared_ptr<ov::Model>& model) override;

protected:
-    void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
+    void copy_runtime_info_from_input_values(const std::shared_ptr<Node>& node);
    /// \brief Folds pre-calculated output tensor values to constants in case lower and
    /// upper estimations are equal. Traverses graph backwards starting from the results.
    bool pre_calculated_values_folding(const std::shared_ptr<ov::Model>& model);
@@ -5,6 +5,7 @@
#include "ngraph/node.hpp"

#include <memory>
#include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>
#include <sstream>
#include <typeindex>
@@ -814,8 +815,10 @@ bool ov::Node::constant_fold(OutputVector& output_values, const OutputVector& in
    if (!all_constants)
        return false;

    NodeVector nodes;
    TensorVector input_tensors;
    for (const auto& input : input_values) {
        nodes.push_back(input.get_node_shared_ptr());
        auto constant = ov::as_type_ptr<ngraph::op::v0::Constant>(input.get_node_shared_ptr());
        auto tensor = ov::Tensor(input.get_element_type(), input.get_shape());
        std::copy_n(constant->get_data_ptr<uint8_t>(), constant->get_byte_size(), static_cast<uint8_t*>(tensor.data()));
@@ -833,6 +836,7 @@ bool ov::Node::constant_fold(OutputVector& output_values, const OutputVector& in
        output_values[i] = make_shared<ngraph::op::Constant>(output_tensors[i].get_element_type(),
                                                             output_tensors[i].get_shape(),
                                                             output_tensors[i].data());
        copy_runtime_info(nodes, output_values[i].get_node_shared_ptr());
    }
    return true;
}
@@ -21,7 +21,9 @@
using namespace std;
using namespace ngraph;

-op::v3::ShapeOf::ShapeOf(const Output<Node>& arg, element::Type output_type) : Op({arg}), m_output_type(output_type) {
+op::v3::ShapeOf::ShapeOf(const Output<Node>& arg, element::Type output_type)
+    : ShapeOfBase({arg}),
+      m_output_type(output_type) {
    constructor_validate_and_infer_types();
}
@@ -206,7 +208,7 @@ bool op::v3::ShapeOf::constant_fold(OutputVector& output_values, const OutputVec
}

// op::v0::ShapeOf
-op::v0::ShapeOf::ShapeOf(const Output<Node>& arg) : Op({arg}) {
+op::v0::ShapeOf::ShapeOf(const Output<Node>& arg) : ShapeOfBase({arg}) {
    constructor_validate_and_infer_types();
}
@@ -9,9 +9,10 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/op/util/read_value_base.hpp"
#include "openvino/op/util/shape_of_base.hpp"
#include "openvino/op/util/sub_graph_base.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset3.hpp"

using namespace std;
@@ -52,6 +53,7 @@ const auto friendly_name_from = [](const ov::Node& node, const size_t output_cou

bool ov::pass::ConstantFolding::run_on_model(const std::shared_ptr<ov::Model>& model) {
    RUN_ON_MODEL_SCOPE(ConstantFolding);

    bool rewritten = pre_calculated_values_folding(model);

    for (const auto& node : model->get_ordered_ops()) {
@@ -76,8 +78,11 @@ bool ov::pass::ConstantFolding::run_on_model(const std::shared_ptr<ov::Model>& m
            replacement.get_node()->set_friendly_name(friendly_name_from(*node, replacements.size(), i));

            node_output.replace(replacement);
-            // Propagate runtime info attributes to replacement consumer nodes
-            copy_runtime_info_to_target_inputs(node, replacement);
+            // Copy runtime info from source nodes
+            // when it was not propagated during pre-calculation
+            copy_runtime_info_from_input_values(node);
+            // Propagate runtime info attributes to replacement
+            copy_runtime_info(node, replacement.get_node_shared_ptr());

            rewritten = true;
        }
@@ -96,12 +101,17 @@ bool ov::pass::ConstantFolding::run_on_model(const std::shared_ptr<ov::Model>& m
    return rewritten;
}

-void ov::pass::ConstantFolding::copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
-                                                                   const Output<Node>& replacement) {
-    for (auto& input : replacement.get_target_inputs()) {
-        auto consumer = input.get_node()->shared_from_this();
-        copy_runtime_info({node, consumer}, consumer);
+void ov::pass::ConstantFolding::copy_runtime_info_from_input_values(const std::shared_ptr<Node>& node) {
+    if (is_type<op::util::ShapeOfBase>(node)) {
+        // Don't propagate names of ShapeOf source node since it is not fused itself
+        return;
    }
+    // Add node itself to merge original rt info with rt info of inputs
+    ov::NodeVector from = {node};
+    for (auto& input : node->input_values()) {
+        from.push_back(input.get_node_shared_ptr());
+    }
+    copy_runtime_info(from, node);
}

bool ov::pass::ConstantFolding::pre_calculated_values_folding(const std::shared_ptr<ov::Model>& model) {
@@ -115,14 +125,20 @@ bool ov::pass::ConstantFolding::pre_calculated_values_folding(const std::shared_

        if (constant_folding_is_disabled(node)) {
            can_be_folded = false;
-        } else if (is_type<ov::opset1::ShapeOf>(node) || is_type<ov::opset3::ShapeOf>(node)) {
+        } else if (is_type<op::util::ShapeOfBase>(node)) {
            // In case if node is ShapeOf operation we stop propagation of can_be_folded attribute. We have to limit
            // propagation because we can't detect borders of shape_of sub-graphs, so we propagate can_be_folded
            // attribute through all nodes including nodes on data path. So to limit the spread of attribute to other
            // shape-of sub-graphs we do not propagate it through ShapeOf nodes.
-            can_be_folded = true;
+            can_be_folded = input_values.begin()->get_partial_shape().is_static();
        } else if (op::util::is_parameter(node) || op::util::is_output(node) || op::util::is_sink(node) ||
                   is_type<op::util::ReadValueBase>(node)) {
            can_be_folded = false;
        } else {
            can_be_folded = std::all_of(input_values.cbegin(), input_values.cend(), is_output_foldable);
            if (input_values.size() && can_be_folded) {
                copy_runtime_info_from_input_values(node);
            }
        }
        node->get_rt_info()["can_be_folded"] = can_be_folded;
    }
@@ -151,8 +167,8 @@ bool ov::pass::ConstantFolding::pre_calculated_values_folding(const std::shared_
                friendly_name_from(*input_node, input_node->get_output_size(), output.get_index()));

            output.replace(replacement);
-            // Propagate runtime info attributes to replacement consumer nodes
-            copy_runtime_info_to_target_inputs(input_node, replacement);
+            // Propagate runtime info attributes to replacement
+            copy_runtime_info(input_node, replacement);

            rewritten = true;
        }
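The observable effect of copy_runtime_info_from_input_values: after ConstantFolding runs, a folded Constant reports the fused names of the subgraph it replaced. A sketch of how that can be checked, mirroring the new tests later in this diff (fold_and_inspect is an assumed helper name):

    #include "ngraph/pass/constant_folding.hpp"
    #include "openvino/pass/manager.hpp"
    #include "transformations/init_node_info.hpp"
    #include "transformations/rt_info/fused_names_attribute.hpp"

    void fold_and_inspect(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();     // seed fused_names
        manager.register_pass<ngraph::pass::ConstantFolding>();  // fold + merge input rt_info
        manager.run_passes(model);

        for (const auto& op : model->get_ordered_ops()) {
            // For a folded Constant this now lists the original subgraph's names.
            std::vector<std::string> fused = ngraph::getFusedNamesVector(op);
            (void)fused;
        }
    }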
@@ -4,17 +4,18 @@

#include "openvino/core/rt_info.hpp"

#include "ngraph/variant.hpp"
#include "openvino/op/util/op_types.hpp"

namespace {

-std::unordered_map<std::string, std::vector<ov::Any>> get_copyable_attrs(const ov::OutputVector& outputs) {
+std::unordered_map<std::string, std::vector<ov::Any>> get_copyable_attrs(const ov::OutputVector& outputs,
+                                                                         const ov::Output<ov::Node>& to) {
    std::unordered_map<std::string, std::vector<ov::Any>> attrs;
    for (const auto& output : outputs) {
        for (const auto& item : output.get_rt_info()) {
            bool copy = true;
            if (item.second.is<ov::RuntimeAttribute>()) {
-                copy = item.second.as<ov::RuntimeAttribute>().is_copyable();
+                copy = item.second.as<ov::RuntimeAttribute>().is_copyable(to.get_node_shared_ptr());
            }
            if (copy) {
                attrs[item.first].push_back(item.second);
@@ -24,13 +25,14 @@ std::unordered_map<std::string, std::vector<ov::Any>> get_copyable_attrs(const o
    return attrs;
}

-std::unordered_map<std::string, std::vector<ov::Any>> get_copyable_attrs(const ov::NodeVector& nodes) {
+std::unordered_map<std::string, std::vector<ov::Any>> get_copyable_attrs(const ov::NodeVector& nodes,
+                                                                         const std::shared_ptr<ov::Node>& to) {
    std::unordered_map<std::string, std::vector<ov::Any>> attrs;
    for (const auto& node : nodes) {
        for (const auto& item : node->get_rt_info()) {
            bool copy = item.first != "opset";
            if (item.second.is<ov::RuntimeAttribute>()) {
-                copy = copy && item.second.as<ov::RuntimeAttribute>().is_copyable();
+                copy = copy && item.second.as<ov::RuntimeAttribute>().is_copyable(to);
            }
            if (copy) {
                attrs[item.first].push_back(item.second);
@@ -41,8 +43,8 @@ std::unordered_map<std::string, std::vector<ov::Any>> get_copyable_attrs(const o
}

template <typename T>
-ov::Node::RTMap mergeRuntimeInfo(const T& items) {
-    std::unordered_map<std::string, std::vector<ov::Any>> attrs = get_copyable_attrs(items);
+ov::Node::RTMap mergeRuntimeInfo(const std::vector<T>& items, const T& to) {
+    std::unordered_map<std::string, std::vector<ov::Any>> attrs = get_copyable_attrs(items, to);

    ov::Node::RTMap merged_attrs;
    for (auto& item : attrs) {
@@ -80,50 +82,61 @@ void assign_runtime_info(const ov::Node::RTMap& from, ov::Node::RTMap& to) {
    }
}

+ov::NodeVector list_with_constants(const ov::NodeVector& to) {
+    ov::NodeVector ops = to;
+    for (auto& node : to) {
+        if (!node) {
+            continue;
+        }
+        for (auto& input : node->inputs()) {
+            auto source_node = input.get_source_output().get_node_shared_ptr();
+            if (ov::op::util::is_constant(source_node) && (0 == source_node->get_rt_info().size())) {
+                if (std::find(ops.begin(), ops.end(), source_node) == ops.end()) {
+                    ops.push_back(source_node);
+                }
+            }
+        }
+    }
+    return ops;
+}
+
+ov::OutputVector list_with_constants(const ov::OutputVector& to) {
+    ov::OutputVector ops = to;
+    for (auto& node : to) {
+        for (auto& input : node.get_node()->inputs()) {
+            auto source_node = input.get_source_output();
+            if (ov::op::util::is_constant(source_node.get_node_shared_ptr()) &&
+                (0 == source_node.get_rt_info().size())) {
+                if (std::find(ops.begin(), ops.end(), source_node) == ops.end()) {
+                    ops.push_back(source_node);
+                }
+            }
+        }
+    }
+    return ops;
+}
} // namespace

void ov::copy_runtime_info(const std::shared_ptr<ov::Node>& from, const std::shared_ptr<ov::Node>& to) {
-    auto& attrs = to->get_rt_info();
-    auto opset = get_opset(attrs);
-
-    for (const auto& item : from->get_rt_info()) {
-        bool copy = item.first != "opset";
-        if (item.second.is<ov::RuntimeAttribute>()) {
-            copy = copy && item.second.as<ov::RuntimeAttribute>().is_copyable();
-        }
-        if (copy) {
-            attrs[item.first] = item.second;
-        }
-    }
-
-    if (!opset.empty()) {
-        attrs["opset"] = opset;
-    }
+    return copy_runtime_info(ov::NodeVector{from}, ov::NodeVector{to});
}

void ov::copy_runtime_info(const std::shared_ptr<ov::Node>& from, ov::NodeVector to) {
-    for (auto& op : to) {
-        copy_runtime_info(from, op);
-    }
+    return copy_runtime_info(ov::NodeVector{from}, to);
}

void ov::copy_runtime_info(const ov::NodeVector& from, const std::shared_ptr<ov::Node>& to) {
-    auto& rtInfoTo = to->get_rt_info();
-    assign_runtime_info(mergeRuntimeInfo(from), rtInfoTo);
+    return copy_runtime_info(from, ov::NodeVector{to});
}

void ov::copy_runtime_info(const ov::NodeVector& from, ov::NodeVector to) {
-    auto mergedInfo = mergeRuntimeInfo(from);
-    for (auto& node : to) {
-        auto& rtInfoTo = node->get_rt_info();
-        assign_runtime_info(mergedInfo, rtInfoTo);
+    for (auto& node : list_with_constants(to)) {
+        assign_runtime_info(mergeRuntimeInfo(from, node), node->get_rt_info());
    }
}

void ov::copy_output_runtime_info(const ov::OutputVector& from, ov::OutputVector to) {
-    auto mergedInfo = mergeRuntimeInfo(from);
-    for (auto& node : to) {
-        auto& rtInfoTo = node.get_rt_info();
-        assign_runtime_info(mergedInfo, rtInfoTo);
+    for (auto& node : list_with_constants(to)) {
+        assign_runtime_info(mergeRuntimeInfo(from, node), node.get_rt_info());
    }
}
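list_with_constants extends the destination set: a Constant feeding a destination node is seeded with the merged attributes as well, provided its own rt_info is empty and each attribute's is_copyable(constant) agrees. A sketch of the resulting behavior (input, old_matmul and weights are hypothetical names):

    // Sketch: `weights` starts with empty rt_info, so copy_runtime_info now seeds it too.
    auto weights = ov::opset8::Constant::create(ov::element::f32, ov::Shape{2, 2}, {1.f, 2.f, 3.f, 4.f});
    auto new_matmul = std::make_shared<ov::opset8::MatMul>(input, weights);  // `input`: some Output<Node>
    ov::copy_runtime_info(old_matmul, new_matmul);  // `old_matmul`: the node being replaced
    // weights->get_rt_info() now holds the same copyable attributes as new_matmul,
    // unless an attribute's is_copyable(weights) returned false.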
@@ -32,6 +32,10 @@ bool RuntimeAttribute::is_copyable() const {
    return true;
}

+bool RuntimeAttribute::is_copyable(const std::shared_ptr<Node>& to) const {
+    return is_copyable();
+}
+
std::ostream& operator<<(std::ostream& os, const RuntimeAttribute& attrubute) {
    return os << attrubute.to_string();
}
File diff suppressed because it is too large.
@@ -14,6 +14,7 @@
#include <map>
#include <memory>
#include <string>
#include <transformations/common_optimizations/fused_names_cleanup.hpp>
#include <unordered_set>

#include "any_copy.hpp"
@@ -33,6 +34,7 @@
#include "openvino/core/model.hpp"
#include "openvino/core/runtime_attribute.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/manager.hpp"
#include "threading/ie_executor_manager.hpp"
#include "transformations/utils/utils.hpp"
@@ -339,6 +341,12 @@ std::unordered_set<std::string> GetSupportedNodes(
    }

    auto transformed_model = model->clone();

    // Clean up fused names if they are present in the original model
    ov::pass::Manager m;
    m.register_pass<ov::pass::FusedNamesCleanup>();
    m.run_passes(transformed_model);

    transform(transformed_model);
    auto ops = transformed_model->get_ordered_ops();
@@ -346,68 +354,80 @@ std::unordered_set<std::string> GetSupportedNodes(
    std::unordered_set<std::string> supported = GetRemovedNodes(model, transformed_model);
    std::unordered_set<std::string> unsupported;

+    auto get_names_set = [](const std::shared_ptr<ov::Node>& op) -> std::unordered_set<std::string> {
+        auto fused_names = ngraph::getFusedNamesVector(op);
+        std::unordered_set<std::string> names(fused_names.begin(), fused_names.end());
+        names.insert(op->get_friendly_name());
+        return names;
+    };
+
+    // Collect all operation names even if there are no such names in the original model
    for (auto&& op : ops) {
-        bool is_supported = false;
-        bool is_checked = false;
-        if (InferenceEngine::details::contains(original_ops, op->get_friendly_name())) {
-            is_supported = is_node_supported(op);
-            is_checked = true;
-            if (is_supported) {
-                supported.emplace(op->get_friendly_name());
-            } else {
-                unsupported.emplace(op->get_friendly_name());
-            }
-        }
-
-        for (auto&& fusedLayerName : ngraph::getFusedNamesVector(op)) {
-            if (InferenceEngine::details::contains(original_ops, fusedLayerName)) {
-                if (!is_checked) {
-                    is_supported = is_node_supported(op);
-                    is_checked = true;
-                }
-                if (is_supported) {
-                    supported.emplace(fusedLayerName);
-                } else {
-                    unsupported.emplace(fusedLayerName);
-                }
-            }
+        auto names = get_names_set(op);
+        if (is_node_supported(op)) {
+            supported.insert(names.begin(), names.end());
+        } else {
+            unsupported.insert(names.begin(), names.end());
        }
    }

    // If operation was fused into several operations where one is supported
    // but another one is not supported remove it from supported
-    for (auto&& name : unsupported) {
-        supported.erase(name);
-    }
-    for (auto&& node : model->get_ops()) {
-        if (InferenceEngine::details::contains(supported, node->get_friendly_name())) {
-            for (auto&& inputNodeOutput : node->input_values()) {
-                if (ov::op::util::is_constant(inputNodeOutput.get_node()) ||
-                    ov::op::util::is_parameter(inputNodeOutput.get_node())) {
-                    supported.emplace(inputNodeOutput.get_node()->get_friendly_name());
-                }
-            }
-            for (auto&& outputs : node->outputs()) {
-                for (auto&& outputNodeInput : outputs.get_target_inputs()) {
-                    if (ov::op::util::is_output(outputNodeInput.get_node())) {
-                        supported.emplace(outputNodeInput.get_node()->get_friendly_name());
-                    }
-                }
-            }
-        }
-
-        if (ov::op::util::is_constant(node) || ov::op::util::is_parameter(node)) {
-            if (node->output(0).get_target_inputs().size() &&
-                !InferenceEngine::details::contains(
-                    supported,
-                    node->output(0).get_target_inputs().begin()->get_node()->get_friendly_name())) {
-                supported.erase(node->get_friendly_name());
-            }
-        } else if (ov::op::util::is_output(node)) {
-            if (!InferenceEngine::details::contains(supported,
-                                                    node->input_values().begin()->get_node()->get_friendly_name())) {
-                supported.erase(node->get_friendly_name());
-            }
-        }
+    for (auto&& unsupportedNode : unsupported) {
+        supported.erase(unsupportedNode);
    }
-    return supported;
+
+    auto has_all_consumers_unsupported = [&supported](const std::shared_ptr<ov::Node>& node) {
+        for (auto&& input : node->output(0).get_target_inputs()) {
+            if (details::contains(supported, input.get_node()->get_friendly_name())) {
+                return false;
+            }
+        }
+        return (node->output(0).get_target_inputs().size() != 0);
+    };
+
+    auto has_unsupported_source = [&supported](const std::shared_ptr<ov::Node>& node) {
+        return !details::contains(supported, node->input_values().begin()->get_node()->get_friendly_name());
+    };
+
+    // Walk over transformed model for special handling of Parameters/Constants/Results
+    for (auto&& op : ops) {
+        // Mark Constants and all fused names as unsupported if they have no
+        // supported consumers/sources
+        if (ov::op::util::is_constant(op)) {
+            if (has_all_consumers_unsupported(op)) {
+                auto names = get_names_set(op);
+                for (auto& name : get_names_set(op)) {
+                    supported.erase(name);
+                }
+            }
+        }
+    }
+
+    // Finally get intersection of all supported operation names
+    // and operation names from original model
+    std::unordered_set<std::string> res;
+    for (auto&& name : supported) {
+        if (details::contains(original_ops, name)) {
+            res.insert(name);
+        }
+    }
+
+    // Remove parameters which have no supported consumers
+    for (auto& param : model->get_parameters()) {
+        if (has_all_consumers_unsupported(param)) {
+            res.erase(param->get_friendly_name());
+        }
+    }
+
+    // Remove results which have no supported source node
+    for (auto& result : model->get_results()) {
+        if (has_unsupported_source(result)) {
+            res.erase(result->get_friendly_name());
+        }
+    }
+
+    return res;
}

void SetExeNetworkInfo(const std::shared_ptr<IExecutableNetworkInternal>& exeNetwork,
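For context, a sketch of how a plugin's QueryNetwork path typically drives this helper; ApplyPluginPipeline and IsOpSupported are hypothetical stand-ins for a plugin's transformation pipeline and per-op support predicate:

    // Sketch of plugin-side usage (hypothetical helpers, OpenVINO 2022.x API).
    std::unordered_set<std::string> supported = InferenceEngine::GetSupportedNodes(
        model,
        [&](std::shared_ptr<ov::Model>& m) {
            ApplyPluginPipeline(m);  // hypothetical: the plugin's transformations
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return IsOpSupported(op);  // hypothetical: per-op support check
        });
    // `supported` now contains only friendly names present in the original model,
    // with Parameters/Results dropped when all their consumers/sources are unsupported.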
new file: src/inference/tests/unit/query_model_test.cpp (420 lines)
@@ -0,0 +1,420 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>

#include <iostream>
#include <openvino/core/rt_info.hpp>

#include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/common_optimizations/nop_elimination.hpp"
#include "transformations/convert_precision.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
#include "transformations/rt_info/decompression.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"

std::ostream& operator<<(std::ostream& os, const std::unordered_set<std::string>& s) {
    for (auto it = s.begin(); it != s.end(); ++it) {
        if (it != s.begin()) {
            os << ", " << *it;
        } else {
            os << *it;
        }
    }
    return os;
}
class GetSupportedNodesTest : public ::testing::Test {
protected:
    ov::Shape m_shape{1, 84};
    std::shared_ptr<ov::Model> m_function;

public:
    void Run(std::function<void(std::shared_ptr<ov::Model>&)> transform,
             std::function<bool(const std::shared_ptr<ngraph::Node>)> is_node_supported,
             const std::unordered_set<std::string>& expected) {
        auto supported = InferenceEngine::GetSupportedNodes(m_function, transform, is_node_supported);
        auto const is_in_expected = [&expected](const std::string& x) {
            return expected.find(x) != expected.end();
        };
        bool is_equal =
            (supported.size() == expected.size()) && std::all_of(supported.begin(), supported.end(), is_in_expected);
        std::stringstream ss;
        if (!is_equal) {
            ss << "Expected list of supported nodes '" << expected << "' but actually received '" << supported << "'";
        }
        ASSERT_TRUE(is_equal) << ss.str();
    }
};
TEST_F(GetSupportedNodesTest, UnsupportedCompressedConstantCF) {
    {
        auto param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param->set_friendly_name("input");
        auto constant_compressed = ngraph::op::Constant::create(ov::element::f16, m_shape, {1});
        constant_compressed->set_friendly_name("constant_compressed");
        auto convert = std::make_shared<ov::opset9::Convert>(constant_compressed, ov::element::f32);
        convert->set_friendly_name("constant");
        ov::mark_as_decompression(convert);
        auto add = std::make_shared<ov::opset9::Add>(param, convert);
        add->set_friendly_name("add");
        auto result = std::make_shared<ngraph::op::Result>(add);
        result->set_friendly_name("result");
        m_function = std::make_shared<ov::Model>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.register_pass<ngraph::pass::ConstantFolding>();
            m.run_passes(model);
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return ov::op::util::is_parameter(op) || ov::op::util::is_constant(op) || ov::op::util::is_output(op);
        },
        {});
}

TEST_F(GetSupportedNodesTest, ConstantSubgraphCF) {
    {
        auto constant_compressed1 = ngraph::op::Constant::create(ov::element::f16, m_shape, {1});
        constant_compressed1->set_friendly_name("constant_compressed1");
        auto convert1 = std::make_shared<ov::opset9::Convert>(constant_compressed1, ov::element::f32);
        convert1->set_friendly_name("constant1");
        ov::mark_as_decompression(convert1);
        auto constant_compressed2 = ngraph::op::Constant::create(ov::element::f16, m_shape, {2});
        constant_compressed2->set_friendly_name("constant_compressed2");
        auto convert2 = std::make_shared<ov::opset9::Convert>(constant_compressed2, ov::element::f32);
        convert2->set_friendly_name("constant2");
        ov::mark_as_decompression(convert2);
        auto add = std::make_shared<ov::opset9::Add>(convert1, convert2);
        add->set_friendly_name("add");
        auto const_reshape = ov::opset9::Constant::create(ngraph::element::i64, ov::Shape{1}, {84});
        const_reshape->set_friendly_name("const_reshape");
        auto reshape = std::make_shared<ov::opset9::Reshape>(add, const_reshape, false);
        reshape->set_friendly_name("reshape");
        auto result = std::make_shared<ngraph::op::Result>(reshape);
        result->set_friendly_name("result");
        m_function = std::make_shared<ov::Model>(ngraph::ResultVector{result}, ngraph::ParameterVector{});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.register_pass<ngraph::pass::ConstantFolding>();
            m.run_passes(model);
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return ov::op::util::is_parameter(op) || ov::op::util::is_constant(op) || ov::op::util::is_output(op);
        },
        {"constant_compressed1",
         "constant1",
         "constant_compressed2",
         "constant2",
         "add",
         "const_reshape",
         "reshape",
         "result"});
}
TEST_F(GetSupportedNodesTest, SupportedCompressedConstantNop) {
    {
        auto param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param->set_friendly_name("input");
        auto constant_compressed = ngraph::op::Constant::create(ov::element::f16, m_shape, {1});
        constant_compressed->set_friendly_name("constant_compressed");
        auto convert = std::make_shared<ov::opset9::Convert>(constant_compressed, ov::element::f32);
        convert->set_friendly_name("constant");
        auto add = std::make_shared<ov::opset9::Add>(param, convert);
        add->set_friendly_name("add");
        auto result = std::make_shared<ngraph::op::Result>(add);
        result->set_friendly_name("result");
        m_function = std::make_shared<ov::Model>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.register_pass<ngraph::pass::ConvertPrecision>(
                precisions_array{{ngraph::element::f16, ngraph::element::f32}});
            m.register_pass<ngraph::pass::NopElimination>();
            m.run_passes(model);
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return ov::op::util::is_parameter(op) || ov::op::util::is_constant(op) || ov::op::util::is_output(op) ||
                   (std::dynamic_pointer_cast<ov::opset9::Add>(op) != nullptr);
        },
        {"input", "constant_compressed", "constant", "add", "result"});
}

TEST_F(GetSupportedNodesTest, SupportedConstantInsertAdditionalOp) {
    {
        auto param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param->set_friendly_name("input");
        auto mul_const = ngraph::op::Constant::create(ov::element::f32, m_shape, {1});
        mul_const->set_friendly_name("constant");
        auto mul = std::make_shared<ov::opset9::Multiply>(param, mul_const);
        mul->set_friendly_name("output_operation");
        auto result = std::make_shared<ngraph::op::Result>(mul);
        result->set_friendly_name("result");
        m_function = std::make_shared<ov::Model>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.run_passes(model);
            for (auto& op : model->get_ops()) {
                if (std::dynamic_pointer_cast<ov::opset9::Multiply>(op) != nullptr) {
                    // Add one more dummy operation
                    auto consumers = op->output(0).get_target_inputs();
                    auto shape = op->get_shape();
                    auto add_const = ngraph::op::Constant::create(ov::element::f32, m_shape, {0});
                    auto add = std::make_shared<ov::opset9::Add>(op, add_const);
                    add->set_friendly_name(op->get_friendly_name());
                    op->set_friendly_name(op->get_friendly_name() + "/previous");
                    ov::copy_runtime_info(op, add);
                    for (auto& consumer : consumers) {
                        consumer.replace_source_output(add);
                    }
                }
            }
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return ov::op::util::is_parameter(op) || ov::op::util::is_constant(op) || ov::op::util::is_output(op) ||
                   (std::dynamic_pointer_cast<ov::opset9::Multiply>(op) != nullptr) ||
                   (std::dynamic_pointer_cast<ov::opset9::Add>(op) != nullptr);
        },
        {"input", "constant", "output_operation", "result"});
}
TEST_F(GetSupportedNodesTest, PartiallySupportedCompressedConstant) {
    {
        auto param1 = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param1->set_friendly_name("input1");
        auto param2 = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param2->set_friendly_name("input2");
        auto constant_compressed = ngraph::op::Constant::create(ov::element::f16, m_shape, {1});
        constant_compressed->set_friendly_name("constant_compressed");
        auto convert = std::make_shared<ov::opset9::Convert>(constant_compressed, ov::element::f32);
        convert->set_friendly_name("constant");
        ov::mark_as_decompression(convert);
        auto add = std::make_shared<ov::opset9::Add>(param1, convert);
        add->set_friendly_name("add");
        auto result1 = std::make_shared<ngraph::op::Result>(add);
        result1->set_friendly_name("result1");
        auto mul = std::make_shared<ov::opset9::Multiply>(param2, convert);
        mul->set_friendly_name("mul");
        auto result2 = std::make_shared<ngraph::op::Result>(mul);
        result2->set_friendly_name("result2");

        m_function = std::make_shared<ov::Model>(ngraph::ResultVector{result1, result2},
                                                 ngraph::ParameterVector{param1, param2});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.register_pass<ngraph::pass::ConstantFolding>();
            m.run_passes(model);
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return ov::op::util::is_parameter(op) || ov::op::util::is_constant(op) || ov::op::util::is_output(op) ||
                   (std::dynamic_pointer_cast<ov::opset9::Multiply>(op) != nullptr);
        },
        {"input2", "constant_compressed", "constant", "mul", "result2"});
}

TEST_F(GetSupportedNodesTest, ConstantSubgraphSupported) {
    {
        auto param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param->set_friendly_name("input");
        auto weights = ov::opset9::Constant::create(ov::element::Type_t::f32, {10, 84}, {1});
        weights->set_friendly_name("weights");
        auto shapeOf = std::make_shared<ov::opset9::ShapeOf>(weights);
        shapeOf->set_friendly_name("shapeof");
        auto const1 = ov::opset9::Constant::create(ov::element::Type_t::i32, {1}, {1});
        const1->set_friendly_name("const1");
        auto const2 = ov::opset9::Constant::create(ov::element::Type_t::i64, {}, {0});
        const2->set_friendly_name("const2");
        auto gather = std::make_shared<ov::opset9::Gather>(shapeOf, const1, const2);
        gather->set_friendly_name("gather");
        auto const3 = ov::opset9::Constant::create(ov::element::Type_t::i64, {1}, {1});
        const3->set_friendly_name("const3");
        auto concat = std::make_shared<ov::opset9::Concat>(ov::NodeVector{const3, gather}, 0);
        concat->set_friendly_name("concat");
        auto reshape = std::make_shared<ov::opset9::Reshape>(param, concat, false);
        reshape->set_friendly_name("reshape");
        auto matmul = std::make_shared<ov::opset9::MatMul>(reshape, weights, false, true);
        matmul->set_friendly_name("matmul");
        auto result = std::make_shared<ngraph::op::Result>(matmul);
        result->set_friendly_name("result");

        m_function = std::make_shared<ov::Model>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.register_pass<ngraph::pass::ConstantFolding>();
            m.register_pass<ngraph::pass::NopElimination>();
            m.run_passes(model);
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return ov::op::util::is_parameter(op) || ov::op::util::is_constant(op) || ov::op::util::is_output(op) ||
                   (std::dynamic_pointer_cast<ov::opset9::MatMul>(op) != nullptr);
        },
        {"input",
         "weights",
         "shapeof",
         "const1",
         "const2",
         "gather",
         "const3",
         "concat",
         "reshape",
         "matmul",
         "result"});
}
TEST_F(GetSupportedNodesTest, UnmarkedSupportedInputsOutputs) {
    {
        auto param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param->set_friendly_name("input");
        auto constant = ngraph::op::Constant::create(ov::element::f32, ov::Shape{m_shape[1]}, {1});
        constant->set_friendly_name("constant");
        auto const_reshape = ov::opset9::Constant::create(ngraph::element::i64, ov::Shape{2}, m_shape);
        const_reshape->set_friendly_name("const_reshape");
        auto reshape = std::make_shared<ov::opset9::Reshape>(constant, const_reshape, false);
        reshape->set_friendly_name("reshape");
        auto add = std::make_shared<ov::opset9::Add>(param, reshape);
        add->set_friendly_name("add");
        auto result = std::make_shared<ngraph::op::Result>(add);
        result->set_friendly_name("result");
        m_function = std::make_shared<ov::Model>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.register_pass<ngraph::pass::ConstantFolding>();
            m.run_passes(model);
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            // Plugin doesn't mark input, constant and result as supported
            return (std::dynamic_pointer_cast<ov::opset9::Add>(op) != nullptr);
        },
        {"add"});
}

TEST_F(GetSupportedNodesTest, WrongFusedNamesInOriginalModel) {
    {
        auto param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param->set_friendly_name("input");
        auto weights = ov::opset9::Constant::create(ov::element::Type_t::f32, {10, 84}, {1});
        weights->set_friendly_name("weights");
        auto matmul = std::make_shared<ov::opset9::MatMul>(param, weights, false, true);
        matmul->get_rt_info()[ngraph::FusedNames::get_type_info_static()] = ngraph::FusedNames("add");
        matmul->set_friendly_name("matmul");
        auto constant = ngraph::op::Constant::create(ov::element::f32, {1, 10}, {1});
        constant->set_friendly_name("constant");
        auto add = std::make_shared<ov::opset9::Add>(matmul, constant);
        add->get_rt_info()[ngraph::FusedNames::get_type_info_static()] = ngraph::FusedNames("matmul");
        add->set_friendly_name("add");
        auto result = std::make_shared<ngraph::op::Result>(add);
        result->set_friendly_name("result");

        m_function = std::make_shared<ov::Model>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            return;
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return ov::op::util::is_parameter(op) || ov::op::util::is_constant(op) || ov::op::util::is_output(op) ||
                   (std::dynamic_pointer_cast<ov::opset9::MatMul>(op) != nullptr);
        },
        {"input", "weights", "matmul"});
}
TEST_F(GetSupportedNodesTest, FusedNamesSupportedUnsupportedBoth) {
    {
        auto param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param->set_friendly_name("input");
        auto dummy_param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        dummy_param->set_friendly_name("dummy_param");
        auto logsoftmax = std::make_shared<ov::opset9::LogSoftmax>(param, 1);
        logsoftmax->set_friendly_name("logsoftmax");
        auto result = std::make_shared<ngraph::op::Result>(logsoftmax);
        result->set_friendly_name("result");
        m_function =
            std::make_shared<ov::Model>(ngraph::ResultVector{result}, ngraph::ParameterVector{param, dummy_param});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.register_pass<ngraph::pass::LogSoftmaxDecomposition>();
            m.run_passes(model);
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            // Exp is not supported and all constants are missing
            return ov::op::util::is_parameter(op) || ov::op::util::is_output(op) ||
                   (std::dynamic_pointer_cast<ov::opset9::ReduceMax>(op) != nullptr) ||
                   (std::dynamic_pointer_cast<ov::opset9::Subtract>(op) != nullptr) ||
                   (std::dynamic_pointer_cast<ov::opset9::ReduceSum>(op) != nullptr) ||
                   (std::dynamic_pointer_cast<ov::opset9::Log>(op) != nullptr);
        },
        {"dummy_param"});  // keep dummy only since it has no unsupported consumers
}

TEST_F(GetSupportedNodesTest, ShapeOfNonConstantNode) {
    {
        auto param = std::make_shared<ngraph::op::Parameter>(ov::element::f32, m_shape);
        param->set_friendly_name("input");
        auto slope_compressed = ov::opset9::Constant::create(ngraph::element::f16, ngraph::Shape{}, {-2.f});
        slope_compressed->set_friendly_name("slope_compressed");
        auto convert_slope = std::make_shared<ov::opset9::Convert>(slope_compressed, ov::element::f32);
        convert_slope->set_friendly_name("slope");
        ov::mark_as_decompression(convert_slope);
        auto prelu = std::make_shared<ov::opset9::PRelu>(param, convert_slope);
        prelu->set_friendly_name("prelu");
        auto shapeOf = std::make_shared<ov::opset9::ShapeOf>(prelu);
        shapeOf->set_friendly_name("shapeof");
        auto convert_fp32 = std::make_shared<ov::opset9::Convert>(shapeOf, ov::element::f32);
        convert_fp32->set_friendly_name("convert_fp32");
        auto scale = ov::opset9::Constant::create(ngraph::element::f32, ngraph::Shape{}, {2.0f});
        scale->set_friendly_name("scale");
        auto mul_scale = std::make_shared<ov::opset9::Multiply>(convert_fp32, scale);
        mul_scale->set_friendly_name("mul_scale");
        auto convert_i64 = std::make_shared<ov::opset9::Convert>(mul_scale, ov::element::i64);
        convert_i64->set_friendly_name("convert_i64");
        auto interpolate = std::make_shared<ov::opset9::Interpolate>(prelu,
                                                                     convert_i64,
                                                                     scale,
                                                                     ov::opset9::Interpolate::InterpolateAttrs());
        interpolate->set_friendly_name("interpolate");
        auto interpolate_result = std::make_shared<ngraph::op::Result>(interpolate);
        interpolate_result->set_friendly_name("interpolate_result");
        m_function =
            std::make_shared<ov::Model>(ngraph::ResultVector{interpolate_result}, ngraph::ParameterVector{param});
    }
    Run(
        [&](std::shared_ptr<ov::Model>& model) {
            ov::pass::Manager m;
            m.register_pass<ngraph::pass::InitNodeInfo>();
            m.register_pass<ngraph::pass::CommonOptimizations>();
            m.run_passes(model);
        },
        [&](const std::shared_ptr<ngraph::Node>& op) {
            return ov::op::util::is_parameter(op) || ov::op::util::is_constant(op) || ov::op::util::is_output(op) ||
                   (std::dynamic_pointer_cast<ov::opset9::PRelu>(op) != nullptr);
        },
        {"input", "slope_compressed", "slope", "prelu"});
}
@@ -102,8 +102,9 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
     * sequence starting from 0 and replace last two dimension. For example for length = 4 the
     * order will be [0, 1, 3, 2] that emulates transpose_a or transpose_b attribute.
     */
+    ngraph::NodeVector new_ops;

-    auto create_transpose = [this](const ngraph::Output<ngraph::Node>& node, const std::string& transpose_name) {
+    auto create_transpose = [this, &new_ops](const ngraph::Output<ngraph::Node>& node, const std::string& transpose_name) {
        auto rank = node.get_partial_shape().rank();
        std::vector<size_t> transpose_order(rank.get_length());
        std::iota(transpose_order.begin(), transpose_order.end(), 0);
@@ -112,13 +113,14 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
        auto transpose_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ transpose_order.size() }, transpose_order);
        auto transpose = ngraph::op::util::make_try_fold<ngraph::opset1::Transpose>(node, transpose_const);
        if (!ngraph::is_type<ngraph::opset1::Constant>(transpose)) {
+            new_ops.push_back(transpose_const);
            MatcherPass::register_new_node(transpose);
        }
        transpose->set_friendly_name(transpose_name);
+        new_ops.push_back(transpose);
        return transpose;
    };

-    ngraph::NodeVector new_ops;
    bool success = true;
    ngraph::PartialShape shape_a_aligned, shape_b_aligned;
    std::tie(success, shape_a_aligned, shape_b_aligned) = get_aligned_shapes();
@@ -137,7 +139,6 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
    // Weights normalization
    if (!matmul->get_transpose_b()) {
        fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b");
-        new_ops.push_back(fc_input_b.get_node_shared_ptr());
    }

    if (rank_b != 2) {
@@ -146,13 +147,15 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
        std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
        auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
        fc_input_b = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
+        if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr())) {
+            new_ops.push_back(reshape_shape);
+        }
        new_ops.push_back(fc_input_b.get_node_shared_ptr());
    }

    // Input normalization
    if (matmul->get_transpose_a() && rank_a != 1) {
        fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
-        new_ops.push_back(fc_input_a.get_node_shared_ptr());
    }

    auto output_rank = matmul->get_output_partial_shape(0).rank();
@@ -89,6 +89,7 @@ ov::intel_cpu::MoveEltwiseUpThroughDataMov::MoveEltwiseUpThroughDataMov() {
        if (is_binary_op && current->get_output_partial_shape(0).rank().get_length() != eltwise->get_input_partial_shape(1).rank().get_length()) {
            auto old_eltwise_const = std::dynamic_pointer_cast<ngraph::opset8::Constant>(eltwise->get_input_node_shared_ptr(1));
            auto new_constant = std::make_shared<ngraph::opset8::Constant>(*old_eltwise_const.get(), ngraph::Shape{});
            ngraph::copy_runtime_info(old_eltwise_const, new_constant);
            ngraph::replace_node(old_eltwise_const, new_constant);
        }
        ngraph::replace_output_update_name(eltwise->output(0), eltwise->input_value(0));
@@ -9,6 +9,7 @@

#include <ngraph/node.hpp>
#include <ngraph/variant.hpp>
#include "openvino/op/util/op_types.hpp"

namespace ov {
namespace intel_cpu {
@@ -25,6 +26,9 @@ public:
    MemoryFormats() = default;
    explicit MemoryFormats(const std::string &_memory_format) : memory_format(_memory_format) {}
    std::string getMemoryFormats() const { return memory_format; }
+    bool is_copyable(const std::shared_ptr<ov::Node>& to) const override {
+        return (!ov::op::util::is_constant(to));
+    }

    ov::Any merge(const ngraph::NodeVector & nodes) const override {
        std::set<std::string> unique_mem_format;
@@ -25,15 +25,18 @@ void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
                                                                   ngraph::Shape{shape.size()}, shape);
    auto reshape_node = std::make_shared<ngraph::opset8::Reshape>(transpose_node->input_value(0), reshape_const, false);
    reshape_node->set_friendly_name(transpose_node->get_friendly_name());
-    ngraph::copy_runtime_info(transpose_node, reshape_node);
+    ngraph::copy_runtime_info(transpose_node, {reshape_node, reshape_const});
    transpose_node->output(0).replace(reshape_node->output(0));
}

void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name, bool before_matmul) {
-    auto create_reshape = [](const ngraph::Shape& shape, std::shared_ptr<ngraph::Node> input_node, const std::string& name) {
+    ngraph::NodeVector new_ops;
+    auto create_reshape = [&new_ops](const ngraph::Shape& shape, std::shared_ptr<ngraph::Node> input_node, const std::string& name) {
        auto reshape_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
                                                                        ngraph::Shape{shape.size()}, shape);
        new_ops.push_back(reshape_const);
        auto node = std::make_shared<ngraph::opset8::Reshape>(input_node, reshape_const, false);
        new_ops.push_back(node);
        node->set_friendly_name(name);
        return node;
    };
@@ -51,23 +54,21 @@ void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string&
    std::iota(std::begin(permute_order), std::end(permute_order), 0);
    std::swap(permute_order[transpose_ids[0]], permute_order[transpose_ids[1]]);

-    ngraph::NodeVector new_ops;
    std::shared_ptr<ngraph::Node> node = prev_node;
    if (!before_matmul) {
        auto shape = prev_node->get_output_shape(0);
        std::swap(shape[0], shape[1]);
        node = create_reshape(shape, node, base_name + "/reshape_before_transpose");
-        new_ops.push_back(node);
    }

    auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
+    new_ops.push_back(transpose_order);
    node = std::make_shared<ngraph::opset8::Transpose>(node, transpose_order);
    node->set_friendly_name(base_name + "/in_transpose");
    new_ops.push_back(node);

    if (before_matmul) {
        node = create_reshape(orig_shape, node, base_name + "/reshape_after_transpose");
-        new_ops.push_back(node);
    }

    ngraph::copy_runtime_info(prev_node, new_ops);
@@ -80,10 +80,11 @@ static bool InsertReshape(
    bool need_reshape_before = !reshape_input_node || reshape_input_node->get_output_shape(0).size() != 2;
    if (need_reshape_before) {
        std::vector<int> before_shape = {-1, static_cast<int>(first_node->get_output_shape(0).back())};
-        auto reshape_before_node = std::make_shared<ngraph::opset8::Reshape>(first_node,
-            std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{before_shape.size()}, before_shape), false);
+        auto reshape_before_node_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
+                                                                                    ngraph::Shape{before_shape.size()}, before_shape);
+        auto reshape_before_node = std::make_shared<ngraph::opset8::Reshape>(first_node, reshape_before_node_const, false);
        reshape_before_node->set_friendly_name(matmul_node->get_friendly_name() + "/reshape_before_matmul");
-        ngraph::copy_runtime_info(first_node, reshape_before_node);
+        ngraph::copy_runtime_info(first_node, {reshape_before_node, reshape_before_node_const});
        matmul_node->input(matmul_input_index).replace_source_output(reshape_before_node->output(0));
        if (auto transpose_node = std::dynamic_pointer_cast<ngraph::opset8::Transpose>(nodes.back())) {
            nodes.pop_back();
@ -103,11 +104,11 @@ static bool InsertReshape(
|
||||
<< " For this reason, there is no way to determine permutation shape.";
|
||||
}
|
||||
std::vector<int> permutation_shape = {1, 0};
|
||||
auto transpose_node_copy_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{permutation_shape.size()}, permutation_shape);
|
||||
auto transpose_node_copy = transpose_node->clone_with_new_inputs(
|
||||
{transpose_node->input_values()[0],
|
||||
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{permutation_shape.size()}, permutation_shape)});
|
||||
ngraph::copy_runtime_info(transpose_node, transpose_node_copy);
|
||||
{transpose_node->input_values()[0], transpose_node_copy_const });
|
||||
ngraph::copy_runtime_info(transpose_node, {transpose_node_copy, transpose_node_copy_const});
|
||||
ngraph::replace_node(transpose_node, transpose_node_copy);
|
||||
nodes.push_back(transpose_node_copy);
|
||||
}
|
||||
@ -124,11 +125,11 @@ static bool InsertReshape(
|
||||
}
|
||||
|
||||
if (need_reshape_after) {
|
||||
auto reshape_after_node = std::make_shared<ngraph::opset8::Reshape>(nodes.back(),
|
||||
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{last_node_shape.size()}, last_node_shape), false);
|
||||
auto reshape_after_node_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{last_node_shape.size()}, last_node_shape);
|
||||
auto reshape_after_node = std::make_shared<ngraph::opset8::Reshape>(nodes.back(), reshape_after_node_const, false);
|
||||
reshape_after_node->set_friendly_name(nodes.back()->get_friendly_name());
|
||||
ngraph::copy_runtime_info(nodes.back(), reshape_after_node);
|
||||
ngraph::copy_runtime_info(nodes.back(), { reshape_after_node, reshape_after_node_const});
|
||||
for (auto consumer : consumers) {
|
||||
consumer.replace_source_output(reshape_after_node);
|
||||
}
|
||||
|
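Each hunk above follows the same recipe: collect every node the pass creates, order and shape Constants included, into a NodeVector, then call copy_runtime_info once so fused_names (and any other copyable rt_info) lands on the whole new subgraph instead of only the top operation. A condensed sketch of the recipe, where input and original_node stand in for whatever the matcher found:

ngraph::NodeVector new_ops;
auto order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{1, 0});
new_ops.push_back(order);  // the Constant must be tracked too
auto transpose = std::make_shared<ngraph::opset8::Transpose>(input, order);
transpose->set_friendly_name(original_node->get_friendly_name() + "/transpose");
new_ops.push_back(transpose);
ngraph::copy_runtime_info(original_node, new_ops);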
@ -96,13 +96,13 @@ bool InsertTransposeAfterConvOrPool::run_on_model(const std::shared_ptr<ngraph::
transposeInShape);
auto reshapeBefore = std::make_shared<ngraph::opset7::Reshape>(node, reshapeConstBefore, false);
reshapeBefore->set_friendly_name(node->get_friendly_name() + "/reshape_out");
ngraph::copy_runtime_info(node, reshapeBefore);
ngraph::copy_runtime_info(node, {reshapeBefore, reshapeConstBefore});

auto transpose_order = transposeInShape.size() == 3 ? ngraph::Shape{0, 2, 1} : ngraph::Shape{0, 3, 1, 2};
auto transpose = std::make_shared<ngraph::opset7::Transpose>(reshapeBefore,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{transpose_order.size()}, transpose_order));
auto transpose_order_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{transpose_order.size()}, transpose_order);
auto transpose = std::make_shared<ngraph::opset7::Transpose>(reshapeBefore, transpose_order_const);
transpose->set_friendly_name(node->get_friendly_name() + "/transpose_out");
ngraph::copy_runtime_info(node, transpose);
ngraph::copy_runtime_info(node, {transpose, transpose_order_const});

for (auto& input : consumers) {
input.replace_source_output(transpose);
@ -450,7 +450,7 @@ bool transform_to_pwl(
m_constant, b_constant, alpha_constant);
pwl->set_base_node(node);
pwl->set_friendly_name(node->get_friendly_name());
ngraph::copy_runtime_info(node, pwl);
ngraph::copy_runtime_info(node, {pwl, m_constant, b_constant, alpha_constant});
replace_node(node, pwl);
return true;
}
@ -29,22 +29,24 @@ static void SwapAndTransposeInputs(
std::shared_ptr<ngraph::Node> fq = nullptr,
std::shared_ptr<ngraph::Node> act = nullptr,
std::shared_ptr<ngraph::Node> transpose = nullptr) {
ngraph::NodeVector new_ops;

auto create_transpose =
[](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
[&new_ops](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();

std::vector<size_t> transpose_order(output_shape.size());
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));

auto transpose = std::make_shared<ngraph::opset8::Transpose>(
node, ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape {transpose_order.size()}, transpose_order));
auto transpose_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape {transpose_order.size()}, transpose_order);
new_ops.push_back(transpose_const);
auto transpose = std::make_shared<ngraph::opset8::Transpose>(node, transpose_const);
transpose->set_friendly_name(transpose_name);
new_ops.push_back(transpose);
return transpose;
};

ngraph::NodeVector new_ops;

auto transpose_matmul_input = [matmul_node, &new_ops, create_transpose](size_t ix) {
std::shared_ptr<ngraph::Node> matmul_input = matmul_node->input_value(ix).get_node_shared_ptr();
auto input_transpose = std::dynamic_pointer_cast<ngraph::opset8::Transpose>(matmul_input);
@ -53,7 +55,6 @@ static void SwapAndTransposeInputs(
ngraph::replace_output_update_name(input_transpose->output(0), input_transpose->input_value(0));
} else {
matmul_input = create_transpose(matmul_node->input_value(ix), matmul_node->get_friendly_name() + "/input_transpose");
new_ops.push_back(matmul_input);
}
return matmul_input;
};
@ -90,18 +91,17 @@ static void SwapAndTransposeInputs(
// output of MatMul will be transposed comparing with original one, so the bias should be transposed too
if (bias->get_output_shape(0).size() > 1) {
bias = create_transpose(bias, bias->get_friendly_name() + "/transpose");
new_ops.push_back(bias);

auto transpose_shape = bias->get_output_shape(0);
auto matmul_shape = matmul_node->get_output_shape(0);
if (transpose_shape.size() > matmul_shape.size()) {
std::vector<size_t> reshape_shape(matmul_shape.size(), 1);
std::copy_if(transpose_shape.begin(), transpose_shape.end(), reshape_shape.begin(), [](size_t e) { return e > 1; });
bias = std::make_shared<ngraph::opset8::Reshape>(bias,
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{reshape_shape.size()}, reshape_shape), false);
auto bias_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{reshape_shape.size()}, reshape_shape);
bias = std::make_shared<ngraph::opset8::Reshape>(bias, bias_const, false);
bias->set_friendly_name(add->get_friendly_name() + "/reshape");
ngraph::copy_runtime_info(add, bias);
ngraph::copy_runtime_info(add, {bias, bias_const});
new_ops.push_back(bias);
}
}
@ -126,7 +126,6 @@ static void SwapAndTransposeInputs(

if (transpose == nullptr) {
new_node = create_transpose(new_node, last_layer_name);
new_ops.push_back(new_node);
} else {
ngraph::replace_output_update_name(transpose->output(0), transpose->input_value(0));
new_node->set_friendly_name(last_layer_name);
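For context on the bias handling above: SwapAndTransposeInputs rewrites MatMul(A, B) using the identity (A·B)ᵀ = Bᵀ·Aᵀ, so the swapped MatMul yields the transpose of the original output. That is why a multi-dimensional bias is transposed (and reshaped when its rank exceeds the MatMul output rank) before the add, and why a trailing transpose restores the original layout when none was already present; the Constants feeding those helper Transposes and Reshapes are exactly the nodes that previously missed the fused_names attribute.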
@ -837,10 +837,8 @@ void check_rt_info(const std::shared_ptr<ngraph::Function>& f) {
static const std::vector<std::string> attrs_to_check{"fused_names_0"};

std::ostringstream err_log;
for (auto& op : f->get_ops()) {
if (ov::op::util::is_constant(op))
continue;

for (auto& op : f->get_ops()) {
const auto& rt_info = op->get_rt_info();
for (const auto& attr_name : attrs_to_check) {
if (!rt_info.count(attr_name)) {
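With the is_constant early-out removed, the fused_names check now covers Constants as well, which is what surfaces any transformation that creates a Constant without the copy_runtime_info calls shown earlier. A hypothetical gtest-style spot-check in the same spirit (node is assumed to come from f->get_ops()):

const auto& rt_info = node->get_rt_info();
ASSERT_TRUE(rt_info.count("fused_names_0"))
    << "fused_names missing on " << node->get_friendly_name();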