Resolve comments 12736 (#12778)
* Comments resolving
* Style and getting rid of asserts
* style
* Update src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp
parent 7601400d99
commit 79f1e720e7
@@ -5,16 +5,18 @@
 #pragma once
 
 #include <memory>
-#include <ngraph/pass/graph_rewrite.hpp>
 #include <vector>
 
-namespace ngraph {
+#include "openvino/pass/graph_rewrite.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
 namespace pass {
 
-class NGRAPH_API LSTMStatesBroadcast;
+class TRANSFORMATIONS_API LSTMStatesBroadcast;
 
 }  // namespace pass
-}  // namespace ngraph
+}  // namespace ov
 
 /**
  * @ingroup ie_transformation_common_api
@@ -22,8 +24,8 @@ class NGRAPH_API LSTMStatesBroadcast;
  * we make them broadcast-able by batch
  */
 
-class ngraph::pass::LSTMStatesBroadcast : public ngraph::pass::FunctionPass {
+class ov::pass::LSTMStatesBroadcast : public ov::pass::ModelPass {
 public:
     OPENVINO_RTTI("LSTMStatesBroadcast", "0");
-    bool run_on_model(const std::shared_ptr<ngraph::Function>& m) override;
+    bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
 };
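As context for the header change above: after the migration the pass is a regular ov::pass::ModelPass, so it can be run through ov::pass::Manager exactly the way the MOCTransformations and SmartReshape hunks later in this commit register it. A minimal standalone sketch (the helper function and its name are invented for illustration):

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/smart_reshape/lstm_states_broadcast.hpp"

// Hypothetical helper: apply only LSTMStatesBroadcast to an already built model.
void apply_lstm_states_broadcast(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::LSTMStatesBroadcast>();  // the renamed ModelPass from this commit
    manager.run_passes(model);
}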
@@ -5,16 +5,18 @@
 #pragma once
 
 #include <memory>
-#include <ngraph/pass/graph_rewrite.hpp>
 #include <vector>
 
-namespace ngraph {
+#include "openvino/pass/graph_rewrite.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
 namespace pass {
 
-class NGRAPH_API ReshapeSinkingMatMul;
+class TRANSFORMATIONS_API ReshapeSinkingMatMul;
 
 }  // namespace pass
-}  // namespace ngraph
+}  // namespace ov
 
 /**
  * @ingroup ie_transformation_common_api
@@ -24,7 +26,7 @@ class NGRAPH_API ReshapeSinkingMatMul;
 * Reshape operators to make batch propagate through freely
 */
 
-class ngraph::pass::ReshapeSinkingMatMul : public ngraph::pass::MatcherPass {
+class ov::pass::ReshapeSinkingMatMul : public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("ReshapeSinkingMatMul", "0");
     ReshapeSinkingMatMul();
@@ -118,10 +118,10 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
     manager.register_pass<ngraph::pass::Validate>();
 
     if (!m_use_shapes) {  // Approved Smart Reshape
-        manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
-        manager.register_pass<ngraph::pass::Validate>();
-        manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
-        manager.register_pass<ngraph::pass::Validate>();
+        manager.register_pass<ov::pass::LSTMStatesBroadcast>();
+        manager.register_pass<ov::pass::Validate>();
+        manager.register_pass<ov::pass::ReshapeSinkingMatMul>();
+        manager.register_pass<ov::pass::Validate>();
     }
 
     manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
@@ -2,26 +2,23 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
+#include "transformations/smart_reshape/lstm_states_broadcast.hpp"
+
 #include <memory>
-#include <ngraph/pass/manager.hpp>
-#include <openvino/op/util/sub_graph_base.hpp>
-#include <openvino/opsets/opset9.hpp>
-#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
-#include <transformations/utils/utils.hpp>
 
 #include "dimension_tracker.hpp"
 #include "itt.hpp"
+#include "openvino/op/util/sub_graph_base.hpp"
+#include "openvino/opsets/opset9.hpp"
+#include "openvino/pass/manager.hpp"
+#include "transformations/utils/utils.hpp"
 
-ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const std::shared_ptr<ov::opset9::Parameter>& parameter,
-                                                       const std::shared_ptr<ov::opset9::TensorIterator>& ti) {
-    const auto& body = ti->get_body();
-    OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
+using namespace std;
+using namespace ov::opset9;
+
+ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const shared_ptr<Parameter>& parameter,
+                                                       const shared_ptr<TensorIterator>& ti) {
     int64_t parameter_index = ti->get_body()->get_parameter_index(parameter);
-    OPENVINO_ASSERT(parameter_index >= 0,
-                    "LSTMStatesBroadcast encountered unregistered parameter ",
-                    parameter,
-                    " related to TI body ",
-                    ti);
     for (const auto& input_descriptor : ti->get_input_descriptions())
         if (input_descriptor->m_body_parameter_index == parameter_index)
             return ti->input(input_descriptor->m_input_index);
@@ -31,13 +28,11 @@ ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const std::shared_ptr<ov:
                     parameter);
 }
 
-std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
-    const std::shared_ptr<ov::opset9::TensorIterator>& ti,
-    const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
-    const auto& body = ti->get_body();
-    OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
+shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(const shared_ptr<TensorIterator>& ti,
+                                                                      const shared_ptr<LSTMCell>& lstm_cell) {
+    const auto& body = ti->get_body();  // body is not nullptr -- we checked earlier
 
-    std::map<ov::opset9::Parameter*, ov::PartialShape> original_shapes;
+    map<Parameter*, ov::PartialShape> original_shapes;
     size_t label = 1;
 
     // mark all input dimensions with labels and making them dynamic, keeping original shapes
@@ -46,9 +41,7 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
         original_shapes[parameter.get()] = pshape;
         if (pshape.rank().is_dynamic())
             continue;
-        for (ngraph::Dimension& n : pshape) {
-            OPENVINO_ASSERT(ov::DimensionTracker::get_label(n) == 0,
-                            "LSTMStatesBroadcast encountered TI with previously tracked dimensions");
+        for (ov::Dimension& n : pshape) {
             n = ov::Dimension::dynamic();
             ov::DimensionTracker::set_label(n, label++);
         }
@@ -68,7 +61,7 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
     }
 
     // batch label was tracked -- finding parameter that delivered it
-    std::shared_ptr<ov::opset9::Parameter> batch_delivering_parameter;
+    shared_ptr<Parameter> batch_delivering_parameter;
     size_t index_of_batch_dim = 0;
 
     size_t batch_label = ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]);
@@ -80,8 +73,11 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
             if (ov::DimensionTracker::get_label(pshape[i]) == batch_label) {
                 batch_delivering_parameter = parameter;
                 index_of_batch_dim = i;
+                break;
             }
         }
+        if (index_of_batch_dim != 0 && batch_delivering_parameter != nullptr)
+            break;
     }
     for (auto& item : original_shapes)
         item.first->set_partial_shape(item.second);
@@ -91,89 +87,83 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
         return nullptr;
 
     const auto& batched_source = get_outer_input_of_ti_by_parameter(batch_delivering_parameter, ti);
-    const auto& batched_shape = std::make_shared<ov::opset9::ShapeOf>(batched_source.get_source_output());
-    const auto& batch = std::make_shared<ov::opset9::Gather>(
-        batched_shape,
-        ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}),
-        ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
+    const auto& batched_shape = make_shared<ShapeOf>(batched_source.get_source_output());
+    const auto& batch = make_shared<Gather>(batched_shape,
+                                            Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}),
+                                            Constant::create(ov::element::i64, ov::Shape{}, {0}));
     return batch;
 }
 
-bool broadcast_state_by_batch(ov::Input<ov::Node> input, const std::shared_ptr<ov::Node>& batch_delivering_node) {
-    auto constant_state =
-        std::dynamic_pointer_cast<ov::opset9::Constant>(input.get_source_output().get_node_shared_ptr());
+bool broadcast_state_by_batch(ov::Input<ov::Node> input, const shared_ptr<ov::Node>& batch_delivering_node) {
+    auto constant_state = dynamic_pointer_cast<Constant>(input.get_source_output().get_node_shared_ptr());
     if (constant_state == nullptr)
         return false;
     const auto& constant_shape = constant_state->get_shape();
-    OPENVINO_ASSERT(constant_shape.size() == 2, "State has unexpected shape ", constant_shape);
     if (constant_shape[0] != 1)
         // we only expect to broadcast LSTM states prepared for batch 1 -- no tiling of batch > 1 will be done
         return false;
 
     const auto& constant_copy = constant_state->copy_with_new_inputs({});
-    const auto& broadcast_by_batch = std::make_shared<ov::opset9::Broadcast>(
+    const auto& broadcast_by_batch = make_shared<Broadcast>(
         constant_copy,
-        std::make_shared<ov::opset9::Concat>(
-            ngraph::NodeVector{batch_delivering_node,
-                               ngraph::op::util::make_try_fold<ov::opset9::Gather>(
-                                   ngraph::op::util::make_try_fold<ov::opset9::ShapeOf>(constant_copy),
-                                   ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1}),
-                                   ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}))},
-            0));
+        make_shared<Concat>(ngraph::NodeVector{batch_delivering_node,
+                                               ngraph::op::util::make_try_fold<Gather>(
+                                                   ngraph::op::util::make_try_fold<ShapeOf>(constant_copy),
+                                                   Constant::create(ov::element::i64, ov::Shape{1}, {1}),
+                                                   Constant::create(ov::element::i64, ov::Shape{}, {0}))},
+                            0));
     input.replace_source_output(broadcast_by_batch->output(0));
     return true;
 }
 
-bool relax_batch_for_initial_states_of_lstm_in_ti(const std::shared_ptr<ov::opset9::TensorIterator>& ti,
-                                                  const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
+bool relax_batch_for_initial_states_of_lstm_in_ti(const shared_ptr<TensorIterator>& ti,
                                                   const shared_ptr<LSTMCell>& lstm_cell) {
     bool rewritten = false;
     auto batch_delivering_node = deduce_outer_source_of_batch_for_inner_lstm_cell(ti, lstm_cell);
     if (batch_delivering_node == nullptr)
         return rewritten;
-    if (auto init_hidden_state =
-            std::dynamic_pointer_cast<ov::opset9::Parameter>(lstm_cell->get_input_node_shared_ptr(1))) {
+    if (auto init_hidden_state = dynamic_pointer_cast<Parameter>(lstm_cell->get_input_node_shared_ptr(1))) {
         auto outer_init_hidden_state_input = get_outer_input_of_ti_by_parameter(init_hidden_state, ti);
         rewritten |= broadcast_state_by_batch(outer_init_hidden_state_input, batch_delivering_node);
     }
-    if (auto init_cell_state =
-            std::dynamic_pointer_cast<ov::opset9::Parameter>(lstm_cell->get_input_node_shared_ptr(2))) {
+    if (auto init_cell_state = dynamic_pointer_cast<Parameter>(lstm_cell->get_input_node_shared_ptr(2))) {
         auto outer_init_cell_state_input = get_outer_input_of_ti_by_parameter(init_cell_state, ti);
         rewritten |= broadcast_state_by_batch(outer_init_cell_state_input, batch_delivering_node);
     }
     return rewritten;
 }
 
-bool relax_batch_for_initial_states_of_lstm(const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
+bool relax_batch_for_initial_states_of_lstm(const shared_ptr<LSTMCell>& lstm_cell) {
     bool rewritten = false;
-    const auto& batched_shape = std::make_shared<ov::opset9::ShapeOf>(lstm_cell->get_input_source_output(0));
-    const auto& batch_delivering_node =
-        std::make_shared<ov::opset9::Gather>(batched_shape,
-                                             ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {0}),
-                                             ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
+    const auto& batched_shape = make_shared<ShapeOf>(lstm_cell->get_input_source_output(0));
+    const auto& batch_delivering_node = make_shared<Gather>(batched_shape,
+                                                            Constant::create(ov::element::i64, ov::Shape{1}, {0}),
+                                                            Constant::create(ov::element::i64, ov::Shape{}, {0}));
     rewritten |= broadcast_state_by_batch(lstm_cell->input(1), batch_delivering_node);
     rewritten |= broadcast_state_by_batch(lstm_cell->input(2), batch_delivering_node);
     return rewritten;
 }
 
-bool ngraph::pass::LSTMStatesBroadcast::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
+bool ov::pass::LSTMStatesBroadcast::run_on_model(const shared_ptr<ov::Model>& f) {
     RUN_ON_FUNCTION_SCOPE(LSTMStatesBroadcast);
     bool rewritten = false;
     for (auto& node : f->get_ordered_ops()) {
         // Recursively apply transformation for sub-graph based operations
-        if (const auto& sub_graph_node = std::dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
+        if (const auto& sub_graph_node = dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
             if (const auto& sub_graph = sub_graph_node->get_function())
                 rewritten |= run_on_model(sub_graph);
 
         // Case without TI (LSTMCell and Constant are in the same ov::Model)
-        if (const auto& lstm_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(node))
+        if (const auto& lstm_cell = dynamic_pointer_cast<LSTMCell>(node))
             rewritten |= relax_batch_for_initial_states_of_lstm(lstm_cell);
 
         // Case with TI (LSTMCell and Constant are in different ov::Model objects)
-        if (auto ti = std::dynamic_pointer_cast<ov::opset9::TensorIterator>(node)) {
+        if (auto ti = dynamic_pointer_cast<TensorIterator>(node)) {
             auto body = ti->get_body();
-            OPENVINO_ASSERT(body, "TensorIterator must have body network");
+            if (body == nullptr)
+                continue;
             for (const auto& body_node : body->get_ordered_ops())
-                if (const auto& lstm_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(body_node))
+                if (const auto& lstm_cell = dynamic_pointer_cast<LSTMCell>(body_node))
                     rewritten |= relax_batch_for_initial_states_of_lstm_in_ti(ti, lstm_cell);
         }
     }
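The batch deduction above works by dimension label tracking: every dimension of every TI body parameter is relaxed to dynamic and tagged with a unique label, shapes are re-validated, and whichever label ends up on the LSTMCell's batch dimension identifies the parameter (and dimension index) that delivers the batch from outside the TensorIterator. A minimal sketch of that labelling primitive, using the same ov::DimensionTracker calls as the pass (the function and its values are made up for the example):

#include "dimension_tracker.hpp"
#include "openvino/core/partial_shape.hpp"

// Tag the batch dimension of an example shape and read the label back.
void label_batch_dimension_example() {
    ov::PartialShape pshape{1, 512};             // example parameter shape
    ov::Dimension& batch = pshape[0];
    batch = ov::Dimension::dynamic();            // relax the dimension
    ov::DimensionTracker::set_label(batch, 1);   // track it with label 1
    size_t label = ov::DimensionTracker::get_label(pshape[0]);
    (void)label;  // label == 1 here; get_label() == 0 means "not tracked"
}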
@@ -2,16 +2,18 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include <ngraph/opsets/opset9.hpp>
-#include <ngraph/pattern/matcher.hpp>
-#include <ngraph/pattern/op/or.hpp>
-#include <ngraph/pattern/op/wrap_type.hpp>
-#include <ngraph/rt_info.hpp>
-#include <transformations/smart_reshape/reshape_sinking.hpp>
+#include "transformations/smart_reshape/reshape_sinking.hpp"
 
 #include "itt.hpp"
+#include "openvino/opsets/opset9.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
 
-ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
+using namespace std;
+using namespace ov::opset9;
+
+ov::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
     MATCHER_SCOPE(ReshapeSinkingMatMul);
     /* Original graph:               Transformed graph:
      *
@@ -25,22 +27,20 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
      *   | shape=[1, S, O]                | shape=[B, S, O]
      */
     auto any_input = pattern::any_input(pattern::has_static_rank());
-    auto reshape_label = ngraph::pattern::wrap_type<opset9::Reshape>(
-        {pattern::any_input(), ngraph::pattern::wrap_type<opset9::Constant>()},
-        pattern::rank_equals(2));
+    auto reshape_label =
+        ov::pass::pattern::wrap_type<Reshape>({pattern::any_input(), ov::pass::pattern::wrap_type<Constant>()},
+                                              pattern::rank_equals(2));
 
-    auto matmul_label =
-        ngraph::pattern::wrap_type<opset9::MatMul>({reshape_label, ngraph::pattern::wrap_type<opset9::Constant>()},
-                                                   pattern::rank_equals(2));
-    auto add_label =
-        ngraph::pattern::wrap_type<opset9::Add>({matmul_label, ngraph::pattern::wrap_type<opset9::Constant>()},
-                                                pattern::rank_equals(2));
+    auto matmul_label = ov::pass::pattern::wrap_type<MatMul>({reshape_label, ov::pass::pattern::wrap_type<Constant>()},
+                                                             pattern::rank_equals(2));
+    auto add_label = ov::pass::pattern::wrap_type<Add>({matmul_label, ov::pass::pattern::wrap_type<Constant>()},
+                                                       pattern::rank_equals(2));
 
-    auto matmul_or_matmul_add_label = std::make_shared<pattern::op::Or>(OutputVector{add_label, matmul_label});
+    auto matmul_or_matmul_add_label = make_shared<pattern::op::Or>(OutputVector{add_label, matmul_label});
 
-    auto reshape_1_label = ngraph::pattern::wrap_type<opset9::Reshape>(
-        {matmul_or_matmul_add_label, ngraph::pattern::wrap_type<opset9::Constant>()},
-        pattern::has_static_rank());
+    auto reshape_1_label =
+        ov::pass::pattern::wrap_type<Reshape>({matmul_or_matmul_add_label, ov::pass::pattern::wrap_type<Constant>()},
+                                              pattern::has_static_rank());
 
     matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
         auto pattern_to_node = m.get_pattern_map();
@@ -48,7 +48,7 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
         // check first Reshape eligibility: has a constant output pattern in a form of [-1, K]
         auto reshape = pattern_to_node.at(reshape_label);
         int64_t K = -1;
-        if (const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(reshape->get_input_node_shared_ptr(1))) {
+        if (const auto& constant = dynamic_pointer_cast<Constant>(reshape->get_input_node_shared_ptr(1))) {
             auto output_pattern_vector = constant->cast_vector<int64_t>();
             if (output_pattern_vector.size() != 2 || output_pattern_vector[0] != -1)
                 return false;
@@ -66,11 +66,11 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
             return false;
 
         // check matmul eligibility: has constant second input in a form of [O, K]
-        auto matmul = std::dynamic_pointer_cast<opset9::MatMul>(pattern_to_node.at(matmul_label));
+        auto matmul = dynamic_pointer_cast<MatMul>(pattern_to_node.at(matmul_label));
         if (!matmul || matmul->get_transpose_a())
             return false;
         int64_t O = -1;
-        if (const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(matmul->get_input_node_shared_ptr(1))) {
+        if (const auto& constant = dynamic_pointer_cast<Constant>(matmul->get_input_node_shared_ptr(1))) {
             const auto& constant_shape = constant->get_shape();
             if (constant_shape.size() != 2)
                 return false;
@@ -86,10 +86,10 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
         // check add eligibility if present: has constant second input that has a form of [1, 1, ..., O] (doesn't
         // broadcast first input)
         if (pattern_to_node.count(add_label)) {
-            auto add = std::dynamic_pointer_cast<opset9::Add>(pattern_to_node.at(add_label));
+            auto add = dynamic_pointer_cast<Add>(pattern_to_node.at(add_label));
             if (!add || add->get_autob() != ov::op::AutoBroadcastType::NUMPY)
                 return false;
-            const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(add->get_input_node_shared_ptr(1));
+            const auto& constant = dynamic_pointer_cast<Constant>(add->get_input_node_shared_ptr(1));
             if (!constant)
                 return false;
             const auto& constant_shape = constant->get_shape();
@@ -106,19 +106,17 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
         // input_shape of the pattern except for the batch and last dimension
         auto reshape_1 = m.get_match_root();
 
-        const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(reshape_1->get_input_node_shared_ptr(1));
+        const auto& constant = dynamic_pointer_cast<Constant>(reshape_1->get_input_node_shared_ptr(1));
         if (constant == nullptr)
             return false;
         auto output_pattern = constant->cast_vector<int64_t>();
-        if (!std::all_of(output_pattern.begin(), output_pattern.end(), [](const int64_t& i) {
+        if (output_pattern.size() != input_rank)
+            return false;
+        if (!all_of(output_pattern.begin(), output_pattern.end(), [](const int64_t& i) {
                 return i > 0;
             }))
             return false;
-        if (output_pattern.size() != input_rank)
-            return false;
-        for (size_t i = 0; i < input_rank; ++i) {
-            if (i == 0)
-                continue;
+        for (size_t i = 1; i < input_rank; ++i) {
             if (i + 1 == input_rank) {
                 if (output_pattern[i] != O)
                     return false;
@@ -129,28 +127,26 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
                 return false;
         }
 
+        auto first_reshape = dynamic_pointer_cast<Reshape>(reshape);
+        auto second_reshape = dynamic_pointer_cast<Reshape>(reshape_1);
+        if (!first_reshape || !second_reshape)
+            return false;
+
         // this is the pattern we are looking for! performing the transformation
-        auto first_reshape = std::dynamic_pointer_cast<opset9::Reshape>(reshape);
-        if (!first_reshape)
-            return false;
         first_reshape->set_special_zero(true);
-        auto second_reshape = std::dynamic_pointer_cast<opset9::Reshape>(reshape_1);
-        if (!second_reshape)
-            return false;
         second_reshape->set_special_zero(true);
 
-        std::vector<int64_t> output_pattern_vector(input_rank - 1, 0);
+        vector<int64_t> output_pattern_vector(input_rank - 1, 0);
         output_pattern_vector.push_back(K);
-        auto new_reshape_constant =
-            opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern_vector);
+        auto new_reshape_constant = Constant::create(ov::element::i64, Shape{input_rank}, output_pattern_vector);
         reshape->input(1).replace_source_output(new_reshape_constant->output(0));
 
         output_pattern[0] = 0;
-        auto new_reshape_1_constant = opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern);
+        auto new_reshape_1_constant = Constant::create(ov::element::i64, Shape{input_rank}, output_pattern);
         reshape_1->input(1).replace_source_output(new_reshape_1_constant->output(0));
 
         return true;
     };
-    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_1_label, matcher_name);
+    auto m = make_shared<ov::pass::pattern::Matcher>(reshape_1_label, matcher_name);
    register_matcher(m, callback);
 }
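To make the matched topology concrete, here is a small model in the shape the matcher's ASCII diagram describes: Reshape to [-1, K], MatMul against an [O, K] constant, Reshape back to [1, S, O]. The builder function, its name and the dimension values are invented for the illustration; after ReshapeSinkingMatMul both Reshape output patterns are rewritten to use special-zero semantics so the batch dimension can change when the model is reshaped:

#include <memory>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/opsets/opset9.hpp"

// Illustrative only: builds the Reshape -> MatMul -> Reshape chain the matcher targets.
std::shared_ptr<ov::Model> make_reshape_matmul_reshape(int64_t S, int64_t K, int64_t O) {
    using namespace ov::opset9;
    auto data = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{1, S, K});
    // First Reshape: flatten everything but the last dimension to [-1, K]
    auto reshape = std::make_shared<Reshape>(data,
                                             Constant::create(ov::element::i64, ov::Shape{2}, {int64_t(-1), K}),
                                             false);
    // MatMul with a constant [O, K] second input (transposed), producing [-1, O]
    auto weights = Constant::create(ov::element::f32,
                                    ov::Shape{static_cast<size_t>(O), static_cast<size_t>(K)},
                                    std::vector<float>(static_cast<size_t>(O * K), 0.0f));
    auto matmul = std::make_shared<MatMul>(reshape, weights, false, true);
    // Second Reshape: restore the [1, S, O] layout
    auto reshape_1 = std::make_shared<Reshape>(matmul,
                                               Constant::create(ov::element::i64, ov::Shape{3}, {int64_t(1), S, O}),
                                               false);
    return std::make_shared<ov::Model>(ov::NodeVector{reshape_1}, ov::ParameterVector{data});
}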
@@ -30,8 +30,8 @@ bool ngraph::pass::SmartReshape::run_on_model(const std::shared_ptr<ngraph::Func
     static_manager.register_pass<ngraph::pass::ReshapeTo1D>();
     static_manager.register_pass<ngraph::pass::TransposeMatMul>();
     static_manager.register_pass<ngraph::pass::BroadcastConstRangeReplacement>();
-    static_manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
-    static_manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
+    static_manager.register_pass<ov::pass::LSTMStatesBroadcast>();
+    static_manager.register_pass<ov::pass::ReshapeSinkingMatMul>();
     static_manager.run_passes(f);
 
     ngraph::pass::Manager dynamic_manager;
@@ -4,20 +4,13 @@
 
 #include <gtest/gtest.h>
 
-#include <openvino/core/model.hpp>
-#include <openvino/opsets/opset9.hpp>
+#include "openvino/core/model.hpp"
+#include "openvino/opsets/opset9.hpp"
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
-template<class T>
-std::shared_ptr<ov::opset9::Constant> create_constant(const std::vector<T>& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) {
-    ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()};
-    return ov::opset9::Constant::create(et, shape, data);
-}
-
-std::shared_ptr<ov::opset9::Constant> create_zero_constant(const ov::element::Type_t et, ov::PartialShape shape) {
-    return ov::opset9::Constant::create(et, shape.to_shape(), {0});
-}
+using namespace std;
+using namespace ov::opset9;
 
 struct LSTMStatesAttributes {
     ov::element::Type_t data_et;
@@ -32,18 +25,18 @@ class LSTMStatesBroadcastTest
 TEST_P(LSTMStatesBroadcastTest, BareLSTM) {
     auto p = GetParam();
 
-    std::shared_ptr<ov::Model> model(nullptr);
+    shared_ptr<ov::Model> model(nullptr);
     {
-        auto parameter = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, p.input_size});
-        auto initial_hidden_state = create_zero_constant(p.data_et, {1, p.hidden_size});
-        auto initial_cell_state = create_zero_constant(p.data_et, {1, p.hidden_size});
-        auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size});
-        auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size});
+        auto parameter = make_shared<Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, p.input_size});
+        auto initial_hidden_state = create_zero_constant(p.data_et, ov::PartialShape{1, p.hidden_size}.to_shape());
+        auto initial_cell_state = create_zero_constant(p.data_et, ov::PartialShape{1, p.hidden_size}.to_shape());
+        auto W = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.input_size}.to_shape());
+        auto R = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.hidden_size}.to_shape());
 
-        auto cell = std::make_shared<ov::opset9::LSTMCell>(
+        auto cell = make_shared<LSTMCell>(
             parameter, initial_hidden_state, initial_cell_state, W, R, static_cast<size_t>(p.hidden_size.get_length()));
 
-        model = std::make_shared<ov::Model>(ov::NodeVector{cell}, ov::ParameterVector{parameter});
+        model = make_shared<ov::Model>(ov::NodeVector{cell}, ov::ParameterVector{parameter});
     }
     ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, p.input_size}));
 }
@@ -56,29 +49,29 @@ class LSTMStatesBroadcastTestWithTI
 TEST_P(LSTMStatesBroadcastTestWithTI, TI_With_LSTM) {
     auto p = GetParam();
 
-    std::shared_ptr<ov::Model> model(nullptr);
+    shared_ptr<ov::Model> model(nullptr);
     {
-        auto X = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, 1, p.input_size});
-        auto H_init = create_zero_constant(ov::element::i64, {1, p.hidden_size});
-        auto C_init = create_zero_constant(ov::element::i64, {1, p.hidden_size});
+        auto X = make_shared<Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, 1, p.input_size});
+        auto H_init = create_zero_constant(ov::element::i64, ov::PartialShape{1, p.hidden_size}.to_shape());
+        auto C_init = create_zero_constant(ov::element::i64, ov::PartialShape{1, p.hidden_size}.to_shape());
 
-        auto Xi = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, 1, p.input_size});
-        auto H_t = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
-        auto C_t = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
+        auto Xi = make_shared<Parameter>(p.data_et, ov::PartialShape{1, 1, p.input_size});
+        auto H_t = make_shared<Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
+        auto C_t = make_shared<Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
 
         // Body
-        auto squeeze = std::make_shared<ov::opset9::Squeeze>(Xi, create_constant<int64_t>({1}));
-        auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size});
-        auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size});
+        auto squeeze = make_shared<Squeeze>(Xi, create_constant<int64_t>({1}));
+        auto W = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.input_size}.to_shape());
+        auto R = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.hidden_size}.to_shape());
 
-        auto lstm_cell = std::make_shared<ov::opset9::LSTMCell>(squeeze, H_t, C_t, W, R, static_cast<size_t>(p.hidden_size.get_length()));
-        auto res_1 = std::make_shared<ov::opset9::Result>(lstm_cell->output(0));
-        auto unsqueeze = std::make_shared<ov::opset9::Unsqueeze>(lstm_cell->output(0), create_constant<int64_t>({1}));
-        auto res_2 = std::make_shared<ov::opset9::Result>(unsqueeze);
-        auto res_3 = std::make_shared<ov::opset9::Result>(lstm_cell->output(1));
-        auto body = std::make_shared<ov::Model>(ov::OutputVector{res_1, res_2, res_3}, ov::ParameterVector{Xi, H_t, C_t});
+        auto lstm_cell = make_shared<LSTMCell>(squeeze, H_t, C_t, W, R, static_cast<size_t>(p.hidden_size.get_length()));
+        auto res_1 = make_shared<Result>(lstm_cell->output(0));
+        auto unsqueeze = make_shared<Unsqueeze>(lstm_cell->output(0), create_constant<int64_t>({1}));
+        auto res_2 = make_shared<Result>(unsqueeze);
+        auto res_3 = make_shared<Result>(lstm_cell->output(1));
+        auto body = make_shared<ov::Model>(ov::OutputVector{res_1, res_2, res_3}, ov::ParameterVector{Xi, H_t, C_t});
 
-        auto tensor_iterator = std::make_shared<ov::opset9::TensorIterator>();
+        auto tensor_iterator = make_shared<TensorIterator>();
         tensor_iterator->set_body(body);
 
         tensor_iterator->set_merged_input(C_t, C_init, res_3);
@@ -88,16 +81,15 @@ TEST_P(LSTMStatesBroadcastTestWithTI, TI_With_LSTM) {
         auto out0 = tensor_iterator->get_iter_value(res_1, -1);
         auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0);
 
-        auto res_ti_1 = std::make_shared<ov::opset9::Result>(tensor_iterator->output(1));
-        auto res_ti_2 = std::make_shared<ov::opset9::Result>(tensor_iterator->output(0));
-        model = std::make_shared<ov::Model>(ov::NodeVector{res_ti_1, res_ti_2},
-                                            ov::ParameterVector{X});
+        auto res_ti_1 = make_shared<Result>(tensor_iterator->output(1));
+        auto res_ti_2 = make_shared<Result>(tensor_iterator->output(0));
+        model = make_shared<ov::Model>(ov::NodeVector{res_ti_1, res_ti_2},
+                                       ov::ParameterVector{X});
     }
-    model->reshape(ov::PartialShape{p.new_batch_size, 1, p.input_size});
     ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, 1, p.input_size}));
 }
 
-static std::vector<LSTMStatesAttributes> params = {
+static vector<LSTMStatesAttributes> params = {
     LSTMStatesAttributes{ov::element::f32, {1}, {2}, {512}, {256}},
     LSTMStatesAttributes{ov::element::f32, {-1}, {2}, {512}, {256}},
 };
@@ -4,21 +4,11 @@
 
 #include <gtest/gtest.h>
 
-#include <openvino/core/model.hpp>
-#include <openvino/opsets/opset9.hpp>
+#include "openvino/core/model.hpp"
+#include "openvino/opsets/opset9.hpp"
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
-template<class T>
-std::shared_ptr<ov::opset9::Constant> create_constant(const std::vector<T>& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) {
-    ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()};
-    return ov::opset9::Constant::create(et, shape, data);
-}
-
-std::shared_ptr<ov::opset9::Constant> create_zero_constant(const ov::element::Type_t et, ov::Shape shape) {
-    return ov::opset9::Constant::create(et, shape, {0});
-}
-
 struct ReshapeSinkingAttributes {
     ov::element::Type_t data_et;
     ov::PartialShape input_shape;
@@ -64,3 +64,7 @@ void check_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_p
     manager.register_pass<ngraph::pass::CheckUniqueNames>(unh, true);
     manager.run_passes(f);
 }
+
+std::shared_ptr<ov::opset8::Constant> create_zero_constant(const ov::element::Type_t& et, const ov::Shape& shape) {
+    return ov::opset8::Constant::create(et, shape, {0});
+}
@@ -69,3 +69,11 @@ size_t count_ops_of_type(const std::shared_ptr<ngraph::Function>& f) {
 
     return count;
 }
+
+template<class T>
+std::shared_ptr<ov::opset8::Constant> create_constant(const std::vector<T>& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) {
+    ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()};
+    return ov::opset8::Constant::create(et, shape, data);
+}
+
+std::shared_ptr<ov::opset8::Constant> create_zero_constant(const ov::element::Type_t& et, const ov::Shape& shape);
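For reference, the helpers promoted into the shared test utilities above are meant to be used from the smart-reshape tests roughly like this (a sketch with made-up values, not part of the commit):

// Example usage of the shared helpers:
auto squeeze_axis = create_constant<int64_t>({1});                              // [1] i64 constant with value 1
auto zero_state = create_zero_constant(ov::element::f32, ov::Shape{1, 256});    // all-zero f32 constant of shape {1, 256}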