Remove deprecated classes/methods usage from Legacy/Tests/VPUPlugin (#3907)
* Fix legacy converter for Mul/Add/Sub ops * Updated VPU plugin to use pass_config;Updated tests to avoid legacy classes/methods * Updated VPU pipeline
This commit is contained in:
parent
c6c1503ba1
commit
08eea31f45
@ -11,17 +11,6 @@
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
#include "legacy/ngraph_ops/scaleshift.hpp"
|
||||
#include "legacy/ngraph_ops/eltwise.hpp"
|
||||
#include "legacy/ngraph_ops/power.hpp"
|
||||
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
@ -33,273 +22,5 @@ class INFERENCE_ENGINE_API_CLASS(ConvertMulOrAddFinally);
|
||||
class ngraph::pass::ConvertMulOrAddFinally: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
// This pass finally converts single Multiply and Add operations to ScaleShift or Power operation
|
||||
ConvertMulOrAddFinally();
|
||||
|
||||
private:
|
||||
template<typename T>
|
||||
void convert_mul_or_add_finally();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
bool convert_to_eltwise(std::shared_ptr<T> & node,
|
||||
ngraph::Output<ngraph::Node> data1,
|
||||
ngraph::Output<ngraph::Node> data2) {
|
||||
ELTWISE_TYPE et;
|
||||
if (std::is_same<T, ngraph::opset1::Multiply>()) {
|
||||
et = ELTWISE_TYPE::Prod;
|
||||
} else if (std::is_same<T, ngraph::opset1::Add>()) {
|
||||
et = ELTWISE_TYPE::Sum;
|
||||
} else if (std::is_same<T, ngraph::opset1::Subtract>()) {
|
||||
et = ELTWISE_TYPE::Sub;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto eltwise = std::make_shared<ngraph::op::Eltwise>(data1, data2, et, node->output(0).get_element_type());
|
||||
eltwise->set_friendly_name(node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(node, eltwise);
|
||||
ngraph::replace_node(node, eltwise);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ngraph::matcher_pass_callback get_callback() {
|
||||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
|
||||
static_assert(std::is_same<T, ngraph::opset1::Add>() || std::is_same<T, ngraph::opset1::Subtract>() || std::is_same<T, ngraph::opset1::Multiply>(),
|
||||
"Unsupported template parameter. Only Add or Multiply allowed!");
|
||||
|
||||
auto lin_op = std::dynamic_pointer_cast<T> (m.get_match_root());
|
||||
if (!lin_op || lin_op->output(0).get_partial_shape().rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto output_shape = lin_op->output(0).get_partial_shape();
|
||||
const auto output_shape_rank = output_shape.rank().get_length();
|
||||
|
||||
const auto intInputs = !lin_op->get_input_element_type(0).is_real() &&
|
||||
!lin_op->get_input_element_type(1).is_real();
|
||||
|
||||
if (!lin_op->get_element_type().is_real() || intInputs) {
|
||||
return convert_to_eltwise<T>(lin_op,
|
||||
lin_op->input(0).get_source_output(),
|
||||
lin_op->input(1).get_source_output());
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Constant> const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(
|
||||
lin_op->input(0).get_source_output().get_node_shared_ptr());
|
||||
auto data_node = lin_op->input(1).get_source_output();
|
||||
if (!const_node) {
|
||||
const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant> (lin_op->input(1).get_source_output().get_node_shared_ptr());
|
||||
data_node = lin_op->input(0).get_source_output();
|
||||
if (!const_node) {
|
||||
return convert_to_eltwise<T>(lin_op,
|
||||
lin_op->input(0).get_source_output(),
|
||||
lin_op->input(1).get_source_output());
|
||||
}
|
||||
}
|
||||
|
||||
/* This lambda checks data and constant shapes for broadcasting
|
||||
For example:
|
||||
1. data_shape{1, 64, 64} and const_shape{64, 1, 1} - constant broadcasts data_shape zero dimension
|
||||
2. data_shape{DYN, 64, 64} and const_shape{1, 1, 64} - constant do not broadcasts data_shape
|
||||
3. data_shape{64, 64} and const_shape{1, 1, 1} - constant broadcasts data_shape with additional dimension
|
||||
*/
|
||||
auto constant_broadcast_output = [](const ngraph::PartialShape & data_pshape, const ngraph::Shape & const_shape) -> bool {
|
||||
if (data_pshape.rank().is_dynamic() || const_shape.size() > static_cast<size_t>(data_pshape.rank().get_length())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<ngraph::Dimension> data_shape(data_pshape);
|
||||
|
||||
auto const_shape_it = const_shape.rbegin();
|
||||
auto data_shape_it = data_shape.rbegin();
|
||||
|
||||
while (const_shape_it != const_shape.rend()) {
|
||||
auto data_dim = *data_shape_it;
|
||||
auto const_dim = *const_shape_it;
|
||||
|
||||
/* DATA DIM - CONST DIM - CONSTANT BROADCAST OUTPUT
|
||||
DYN - 64 - TRUE
|
||||
DYN - 1 - FALSE
|
||||
64 - 1 - FALSE
|
||||
1 - 64 - TRUE
|
||||
64 - 64 - FALSE
|
||||
*/
|
||||
if ((data_dim.is_dynamic() && const_dim != 1) ||
|
||||
(data_dim.is_static() && data_dim.get_length() == 1 && const_dim != 1)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
++const_shape_it;
|
||||
++data_shape_it;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
// Check that eltwise is not useless and do not broadcast output otherwise we remove it
|
||||
if (((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
|
||||
(std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) &&
|
||||
!constant_broadcast_output(data_node.get_partial_shape(), const_node->get_shape())) {
|
||||
bool ret_status = ngraph::replace_output_update_name(lin_op->output(0), data_node);
|
||||
if (ret_status) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
auto res = check_constant(const_node, data_node.get_partial_shape());
|
||||
|
||||
auto checkElementwise = [](const std::shared_ptr<ngraph::Node>& elementwise) -> bool {
|
||||
const ngraph::PartialShape partialShape = elementwise->get_input_partial_shape(0);
|
||||
if (partialShape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Constant> constant = ngraph::as_type_ptr<ngraph::opset1::Constant>(elementwise->get_input_node_shared_ptr(1));
|
||||
if (constant == nullptr) {
|
||||
constant = ngraph::as_type_ptr<ngraph::opset1::Constant>(elementwise->get_input_node_shared_ptr(0));
|
||||
}
|
||||
if (constant == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ngraph::Shape constShape = constant->get_output_shape(0);
|
||||
if ((constShape.size() > 5ul)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((constShape.size() <= 1ul) || (std::all_of(constShape.begin(), constShape.end(), [](const size_t value) { return value == 1ul; }))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const ngraph::Shape shape = partialShape.to_shape();
|
||||
if (constShape.size() == shape.size()) {
|
||||
if ((constShape[0] != 1ul) || (constShape[1] != shape[1])) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 2ul; i < constShape.size(); ++i) {
|
||||
if (constShape[i] != 1ul) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if (constShape.size() == (shape.size() - 1)) {
|
||||
if (constShape[0] != shape[1]) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 1ul; i < constShape.size(); ++i) {
|
||||
if (constShape[i] != 1ul) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
bool is_dequantization = (lin_op->get_rt_info().count("DEQUANTIZATION") != 0) && checkElementwise(lin_op);
|
||||
|
||||
if (!is_dequantization && (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4))) {
|
||||
return convert_to_eltwise<T>(lin_op,
|
||||
lin_op->input(0).get_source_output(),
|
||||
lin_op->input(1).get_source_output());
|
||||
}
|
||||
|
||||
// TODO: if all values in Constant are equal the best way is to convert this Eltwise to Power
|
||||
if (res == CONVERSION_RESULT::SCALE_SHIFT || is_dequantization) {
|
||||
auto weights_et = const_node->get_element_type();
|
||||
auto weights_shape = const_node->get_shape();
|
||||
|
||||
// In case of Add we create fake weights with 1, in case of Multiply we create fake bias with 0
|
||||
std::shared_ptr<ngraph::op::ScaleShiftIE> scaleshift;
|
||||
if (std::is_same<T, ngraph::opset1::Add>()) {
|
||||
auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
|
||||
auto weights_in = ngraph::op::util::normalize_constant(weights, output_shape);
|
||||
auto biases_in = ngraph::op::util::normalize_constant(const_node, output_shape);
|
||||
if (is_dequantization) {
|
||||
const ngraph::Shape data_shape = data_node.get_shape();
|
||||
ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
|
||||
broadcasted_shape[1] = data_shape[1];
|
||||
|
||||
weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
|
||||
biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
|
||||
}
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
|
||||
} else if (std::is_same<T, ngraph::opset1::Subtract>()) {
|
||||
std::shared_ptr<ngraph::Node> new_const_node = std::make_shared<ngraph::opset1::Multiply>(
|
||||
ngraph::op::util::normalize_constant(const_node, output_shape),
|
||||
ngraph::opset1::Constant::create(weights_et, ngraph::Shape{ 1 }, { -1 }));
|
||||
|
||||
auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
|
||||
auto weights_in = ngraph::op::util::normalize_constant(weights, output_shape);
|
||||
auto biases_in = new_const_node;
|
||||
if (is_dequantization) {
|
||||
const ngraph::Shape data_shape = data_node.get_shape();
|
||||
ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
|
||||
broadcasted_shape[1] = data_shape[1];
|
||||
|
||||
weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
|
||||
biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
|
||||
}
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
|
||||
} else if (std::is_same<T, ngraph::opset1::Multiply>()) {
|
||||
auto bias = ngraph::opset1::Constant::create(weights_et, weights_shape, {0});
|
||||
auto weights_in = ngraph::op::util::normalize_constant(const_node, output_shape);
|
||||
auto biases_in = ngraph::op::util::normalize_constant(bias, output_shape);
|
||||
if (is_dequantization) {
|
||||
const ngraph::Shape data_shape = data_node.get_shape();
|
||||
ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
|
||||
broadcasted_shape[1] = data_shape[1];
|
||||
|
||||
weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
|
||||
biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
|
||||
}
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
scaleshift->set_friendly_name(lin_op->get_friendly_name());
|
||||
ngraph::copy_runtime_info(m.get_match_root(), scaleshift);
|
||||
ngraph::replace_node(m.get_match_root(), scaleshift);
|
||||
} else {
|
||||
float value;
|
||||
if (!ngraph::op::util::get_single_value(const_node, value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// In case Add we create fake scale equal to 1, in case of Multiply we create fake shift equal to 0
|
||||
std::shared_ptr<ngraph::op::PowerIE> power;
|
||||
if (std::is_same<T, ngraph::opset1::Add>()) {
|
||||
power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, 1.0f, value, lin_op->get_output_element_type(0));
|
||||
} else if (std::is_same<T, ngraph::opset1::Multiply>()) {
|
||||
power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, value, 0.0f, lin_op->get_output_element_type(0));
|
||||
} else if (std::is_same<T, ngraph::opset1::Subtract>()) {
|
||||
power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, 1.0f, -value, lin_op->get_output_element_type(0));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
power->set_friendly_name(lin_op->get_friendly_name());
|
||||
ngraph::copy_runtime_info(m.get_match_root(), power);
|
||||
ngraph::replace_node(m.get_match_root(), power);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
return callback;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ngraph::pass::ConvertMulOrAddFinally::convert_mul_or_add_finally() {
|
||||
auto data_batch_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 2, 1, 1});
|
||||
auto data_batch_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 2, 1, 1});
|
||||
|
||||
auto lin_op = std::make_shared<T>(data_batch_1, data_batch_2);
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(lin_op);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, get_callback<T>(), PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
@ -3,11 +3,295 @@
|
||||
//
|
||||
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp"
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp"
|
||||
|
||||
#include "legacy/ngraph_ops/scaleshift.hpp"
|
||||
#include "legacy/ngraph_ops/eltwise.hpp"
|
||||
#include "legacy/ngraph_ops/power.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMulOrAddFinally, "ConvertMulOrAddFinally", 0);
|
||||
|
||||
ngraph::pass::ConvertMulOrAddFinally::ConvertMulOrAddFinally() {
|
||||
convert_mul_or_add_finally<ngraph::opset1::Add>();
|
||||
convert_mul_or_add_finally<ngraph::opset1::Subtract>();
|
||||
convert_mul_or_add_finally<ngraph::opset1::Multiply>();
|
||||
template <typename T>
|
||||
bool convert_to_eltwise(std::shared_ptr<T> & node,
|
||||
ngraph::Output<ngraph::Node> data1,
|
||||
ngraph::Output<ngraph::Node> data2) {
|
||||
ELTWISE_TYPE et;
|
||||
if (std::is_same<T, ngraph::opset1::Multiply>()) {
|
||||
et = ELTWISE_TYPE::Prod;
|
||||
} else if (std::is_same<T, ngraph::opset1::Add>()) {
|
||||
et = ELTWISE_TYPE::Sum;
|
||||
} else if (std::is_same<T, ngraph::opset1::Subtract>()) {
|
||||
et = ELTWISE_TYPE::Sub;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto eltwise = std::make_shared<ngraph::op::Eltwise>(data1, data2, et, node->output(0).get_element_type());
|
||||
eltwise->set_friendly_name(node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(node, eltwise);
|
||||
ngraph::replace_node(node, eltwise);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ngraph::matcher_pass_callback get_callback() {
|
||||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
|
||||
static_assert(std::is_same<T, ngraph::opset1::Add>() || std::is_same<T, ngraph::opset1::Subtract>() || std::is_same<T, ngraph::opset1::Multiply>(),
|
||||
"Unsupported template parameter. Only Add or Multiply allowed!");
|
||||
|
||||
auto lin_op = std::dynamic_pointer_cast<T> (m.get_match_root());
|
||||
if (!lin_op || lin_op->output(0).get_partial_shape().rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto output_shape = lin_op->output(0).get_partial_shape();
|
||||
const auto output_shape_rank = output_shape.rank().get_length();
|
||||
|
||||
const auto intInputs = !lin_op->get_input_element_type(0).is_real() &&
|
||||
!lin_op->get_input_element_type(1).is_real();
|
||||
|
||||
if (!lin_op->get_element_type().is_real() || intInputs) {
|
||||
return convert_to_eltwise<T>(lin_op,
|
||||
lin_op->input(0).get_source_output(),
|
||||
lin_op->input(1).get_source_output());
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Constant> const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(
|
||||
lin_op->input(0).get_source_output().get_node_shared_ptr());
|
||||
auto data_node = lin_op->input(1).get_source_output();
|
||||
if (!const_node) {
|
||||
const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant> (lin_op->input(1).get_source_output().get_node_shared_ptr());
|
||||
data_node = lin_op->input(0).get_source_output();
|
||||
if (!const_node) {
|
||||
return convert_to_eltwise<T>(lin_op,
|
||||
lin_op->input(0).get_source_output(),
|
||||
lin_op->input(1).get_source_output());
|
||||
}
|
||||
}
|
||||
|
||||
/* This lambda checks data and constant shapes for broadcasting
|
||||
For example:
|
||||
1. data_shape{1, 64, 64} and const_shape{64, 1, 1} - constant broadcasts data_shape zero dimension
|
||||
2. data_shape{DYN, 64, 64} and const_shape{1, 1, 64} - constant do not broadcasts data_shape
|
||||
3. data_shape{64, 64} and const_shape{1, 1, 1} - constant broadcasts data_shape with additional dimension
|
||||
*/
|
||||
auto constant_broadcast_output = [](const ngraph::PartialShape & data_pshape, const ngraph::Shape & const_shape) -> bool {
|
||||
if (data_pshape.rank().is_dynamic() || const_shape.size() > static_cast<size_t>(data_pshape.rank().get_length())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<ngraph::Dimension> data_shape(data_pshape);
|
||||
|
||||
auto const_shape_it = const_shape.rbegin();
|
||||
auto data_shape_it = data_shape.rbegin();
|
||||
|
||||
while (const_shape_it != const_shape.rend()) {
|
||||
auto data_dim = *data_shape_it;
|
||||
auto const_dim = *const_shape_it;
|
||||
|
||||
/* DATA DIM - CONST DIM - CONSTANT BROADCAST OUTPUT
|
||||
DYN - 64 - TRUE
|
||||
DYN - 1 - FALSE
|
||||
64 - 1 - FALSE
|
||||
1 - 64 - TRUE
|
||||
64 - 64 - FALSE
|
||||
*/
|
||||
if ((data_dim.is_dynamic() && const_dim != 1) ||
|
||||
(data_dim.is_static() && data_dim.get_length() == 1 && const_dim != 1)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
++const_shape_it;
|
||||
++data_shape_it;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
// Check that eltwise is not useless and do not broadcast output otherwise we remove it
|
||||
if (((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
|
||||
(std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) &&
|
||||
!constant_broadcast_output(data_node.get_partial_shape(), const_node->get_shape())) {
|
||||
bool ret_status = ngraph::replace_output_update_name(lin_op->output(0), data_node);
|
||||
if (ret_status) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
auto res = check_constant(const_node, data_node.get_partial_shape());
|
||||
|
||||
auto checkElementwise = [](const std::shared_ptr<ngraph::Node>& elementwise) -> bool {
|
||||
const ngraph::PartialShape partialShape = elementwise->get_input_partial_shape(0);
|
||||
if (partialShape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Constant> constant = ngraph::as_type_ptr<ngraph::opset1::Constant>(elementwise->get_input_node_shared_ptr(1));
|
||||
if (constant == nullptr) {
|
||||
constant = ngraph::as_type_ptr<ngraph::opset1::Constant>(elementwise->get_input_node_shared_ptr(0));
|
||||
}
|
||||
if (constant == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ngraph::Shape constShape = constant->get_output_shape(0);
|
||||
if ((constShape.size() > 5ul)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((constShape.size() <= 1ul) || (std::all_of(constShape.begin(), constShape.end(), [](const size_t value) { return value == 1ul; }))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const ngraph::Shape shape = partialShape.to_shape();
|
||||
if (constShape.size() == shape.size()) {
|
||||
if ((constShape[0] != 1ul) || (constShape[1] != shape[1])) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 2ul; i < constShape.size(); ++i) {
|
||||
if (constShape[i] != 1ul) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if (constShape.size() == (shape.size() - 1)) {
|
||||
if (constShape[0] != shape[1]) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 1ul; i < constShape.size(); ++i) {
|
||||
if (constShape[i] != 1ul) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
bool is_dequantization = (lin_op->get_rt_info().count("DEQUANTIZATION") != 0) && checkElementwise(lin_op);
|
||||
|
||||
if (!is_dequantization && (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4))) {
|
||||
return convert_to_eltwise<T>(lin_op,
|
||||
lin_op->input(0).get_source_output(),
|
||||
lin_op->input(1).get_source_output());
|
||||
}
|
||||
|
||||
// TODO: if all values in Constant are equal the best way is to convert this Eltwise to Power
|
||||
if (res == CONVERSION_RESULT::SCALE_SHIFT || is_dequantization) {
|
||||
auto weights_et = const_node->get_element_type();
|
||||
auto weights_shape = const_node->get_shape();
|
||||
|
||||
// In case of Add we create fake weights with 1, in case of Multiply we create fake bias with 0
|
||||
std::shared_ptr<ngraph::op::ScaleShiftIE> scaleshift;
|
||||
if (std::is_same<T, ngraph::opset1::Add>()) {
|
||||
auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
|
||||
auto weights_in = ngraph::op::util::normalize_constant(weights, output_shape);
|
||||
auto biases_in = ngraph::op::util::normalize_constant(const_node, output_shape);
|
||||
if (is_dequantization) {
|
||||
const ngraph::Shape data_shape = data_node.get_shape();
|
||||
ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
|
||||
broadcasted_shape[1] = data_shape[1];
|
||||
|
||||
weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
|
||||
biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
|
||||
}
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
|
||||
} else if (std::is_same<T, ngraph::opset1::Subtract>()) {
|
||||
std::shared_ptr<ngraph::Node> new_const_node = std::make_shared<ngraph::opset1::Multiply>(
|
||||
ngraph::op::util::normalize_constant(const_node, output_shape),
|
||||
ngraph::opset1::Constant::create(weights_et, ngraph::Shape{ 1 }, { -1 }));
|
||||
|
||||
auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
|
||||
auto weights_in = ngraph::op::util::normalize_constant(weights, output_shape);
|
||||
auto biases_in = new_const_node;
|
||||
if (is_dequantization) {
|
||||
const ngraph::Shape data_shape = data_node.get_shape();
|
||||
ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
|
||||
broadcasted_shape[1] = data_shape[1];
|
||||
|
||||
weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
|
||||
biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
|
||||
}
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
|
||||
} else if (std::is_same<T, ngraph::opset1::Multiply>()) {
|
||||
auto bias = ngraph::opset1::Constant::create(weights_et, weights_shape, {0});
|
||||
auto weights_in = ngraph::op::util::normalize_constant(const_node, output_shape);
|
||||
auto biases_in = ngraph::op::util::normalize_constant(bias, output_shape);
|
||||
if (is_dequantization) {
|
||||
const ngraph::Shape data_shape = data_node.get_shape();
|
||||
ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
|
||||
broadcasted_shape[1] = data_shape[1];
|
||||
|
||||
weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
|
||||
biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
|
||||
}
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
scaleshift->set_friendly_name(lin_op->get_friendly_name());
|
||||
ngraph::copy_runtime_info(m.get_match_root(), scaleshift);
|
||||
ngraph::replace_node(m.get_match_root(), scaleshift);
|
||||
} else {
|
||||
float value;
|
||||
if (!ngraph::op::util::get_single_value(const_node, value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// In case Add we create fake scale equal to 1, in case of Multiply we create fake shift equal to 0
|
||||
std::shared_ptr<ngraph::op::PowerIE> power;
|
||||
if (std::is_same<T, ngraph::opset1::Add>()) {
|
||||
power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, 1.0f, value, lin_op->get_output_element_type(0));
|
||||
} else if (std::is_same<T, ngraph::opset1::Multiply>()) {
|
||||
power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, value, 0.0f, lin_op->get_output_element_type(0));
|
||||
} else if (std::is_same<T, ngraph::opset1::Subtract>()) {
|
||||
power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, 1.0f, -value, lin_op->get_output_element_type(0));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
power->set_friendly_name(lin_op->get_friendly_name());
|
||||
ngraph::copy_runtime_info(m.get_match_root(), power);
|
||||
ngraph::replace_node(m.get_match_root(), power);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
return callback;
|
||||
}
|
||||
|
||||
class ConvertAdd: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertAdd() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<ngraph::opset1::Add>());
|
||||
register_matcher(m, get_callback<ngraph::opset1::Add>());
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertSub: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertSub() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<ngraph::opset1::Subtract>());
|
||||
register_matcher(m, get_callback<ngraph::opset1::Subtract>());
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertMul: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertMul() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<ngraph::opset1::Multiply>());
|
||||
register_matcher(m, get_callback<ngraph::opset1::Multiply>());
|
||||
}
|
||||
};
|
||||
|
||||
ngraph::pass::ConvertMulOrAddFinally::ConvertMulOrAddFinally() {
|
||||
add_matcher<ConvertMul>();
|
||||
add_matcher<ConvertAdd>();
|
||||
add_matcher<ConvertSub>();
|
||||
}
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
namespace vpu {
|
||||
|
||||
class EliminateShapeOfAfterDSR : public ngraph::pass::GraphRewrite {
|
||||
class EliminateShapeOfAfterDSR : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
EliminateShapeOfAfterDSR();
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
namespace vpu {
|
||||
|
||||
class MergeSubsequentDSROperations : public ngraph::pass::GraphRewrite {
|
||||
class MergeSubsequentDSROperations : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
MergeSubsequentDSROperations();
|
||||
|
@ -13,14 +13,14 @@ NGRAPH_RTTI_DEFINITION(vpu::EliminateShapeOfAfterDSR, "EliminateShapeOfAfterDSR"
|
||||
|
||||
namespace vpu {
|
||||
|
||||
EliminateShapeOfAfterDSR::EliminateShapeOfAfterDSR() : GraphRewrite() {
|
||||
EliminateShapeOfAfterDSR::EliminateShapeOfAfterDSR() {
|
||||
// We don't set strict_mode when use pattern Matcher,
|
||||
// so we can set any type and shape for input.
|
||||
auto inputWithAnyTypeAndShape = std::make_shared<ngraph::pattern::op::Label>(
|
||||
ngraph::element::dynamic, ngraph::PartialShape{});
|
||||
auto shapeOfPattern = std::make_shared<ngraph::opset3::ShapeOf>(inputWithAnyTypeAndShape);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher &m) {
|
||||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
|
||||
auto shapeOfNode = std::dynamic_pointer_cast<ngraph::opset3::ShapeOf>(m.get_match_root());
|
||||
if (!shapeOfNode) {
|
||||
return false;
|
||||
@ -36,9 +36,7 @@ EliminateShapeOfAfterDSR::EliminateShapeOfAfterDSR() : GraphRewrite() {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(shapeOfPattern, "EliminateShapeOfAfterDSR");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
} // namespace vpu
|
||||
|
@ -9,8 +9,8 @@ NGRAPH_RTTI_DEFINITION(vpu::MergeSubsequentDSROperations, "MergeSubsequentDSROpe
|
||||
|
||||
namespace vpu {
|
||||
|
||||
MergeSubsequentDSROperations::MergeSubsequentDSROperations() : ngraph::pass::GraphRewrite() {
|
||||
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher& m) {
|
||||
MergeSubsequentDSROperations::MergeSubsequentDSROperations() {
|
||||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
|
||||
const auto& dsr = std::dynamic_pointer_cast<ngraph::vpu::op::DynamicShapeResolver>(m.get_match_root());
|
||||
if (!dsr) {
|
||||
return false;
|
||||
@ -31,9 +31,7 @@ MergeSubsequentDSROperations::MergeSubsequentDSROperations() : ngraph::pass::Gra
|
||||
ngraph::pattern::has_class<ngraph::vpu::op::DynamicShapeResolver>());
|
||||
|
||||
const auto& matcher = std::make_shared<ngraph::pattern::Matcher>(label, "MergeSubsequentDSROperations");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
add_matcher(matcher, callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(matcher, callback);
|
||||
}
|
||||
|
||||
} // namespace vpu
|
||||
|
@ -28,6 +28,10 @@
|
||||
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
|
||||
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
|
||||
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
|
||||
#include <transformations/op_conversions/convert_gelu.hpp>
|
||||
#include <transformations/op_conversions/softplus_decomposition.hpp>
|
||||
#include <transformations/op_conversions/convert_minimum_to_power_and_max.hpp>
|
||||
#include <transformations/op_conversions/hswish_decomposition.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
|
||||
#include <transformations/common_optimizations/common_optimizations.hpp>
|
||||
@ -39,6 +43,8 @@
|
||||
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
|
||||
#include <legacy/ie_util_internal.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_gather_to_gather_ie.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_matmul_to_fc_or_gemm.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.hpp>
|
||||
#include <vpu/ngraph/transformations/extract_dynamic_batch/extract_dynamic_batch.hpp>
|
||||
|
||||
namespace vpu {
|
||||
@ -154,22 +160,6 @@ ModelPtr FrontEnd::buildInitialModel(const ie::ICNNNetwork& network) {
|
||||
}
|
||||
|
||||
ie::ICNNNetwork::Ptr FrontEnd::convertNetwork(ie::ICNNNetwork& network) {
|
||||
// disable transformations for some cases
|
||||
const auto transformationsPredicate = [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
|
||||
const bool casesWithDynamicOrStaticUsage =
|
||||
std::dynamic_pointer_cast<const ngraph::opset3::Gelu>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::SoftPlus>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset5::Minimum>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset5::HSwish>(node);
|
||||
|
||||
const bool casesWithOnlyDynamicUsage =
|
||||
(std::dynamic_pointer_cast<const ngraph::opset3::MatMul>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset3::StridedSlice>(node)) &&
|
||||
std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr());
|
||||
|
||||
return casesWithDynamicOrStaticUsage || casesWithOnlyDynamicUsage;
|
||||
};
|
||||
|
||||
auto nGraphFunc = network.getFunction();
|
||||
// Disable shape inference (WA for generic operations)
|
||||
ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
|
||||
@ -195,15 +185,23 @@ ie::ICNNNetwork::Ptr FrontEnd::convertNetwork(ie::ICNNNetwork& network) {
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||
manager.get_pass_config()->disable<ngraph::pass::ConvertGatherToGatherIEMatcher>();
|
||||
manager.register_pass<vpu::MergeSubsequentDSROperations>();
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->disable<ngraph::pass::ConvertGatherToGatherIEMatcher>();
|
||||
pass_config->disable<ngraph::pass::ConvertGELU>();
|
||||
pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
|
||||
pass_config->disable<ngraph::pass::ConvertMinimum>();
|
||||
pass_config->disable<ngraph::pass::HSwishDecomposition>();
|
||||
|
||||
auto transformationPredicate = [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
|
||||
return !!std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr());
|
||||
};
|
||||
pass_config->set_callback<ngraph::pass::ConvertMatMulToFC,
|
||||
ngraph::pass::ConvertStridedSliceToCropMatcher>(transformationPredicate);
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
manager.set_callback(transformationsPredicate);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
manager.run_passes(nGraphFunc);
|
||||
|
||||
vpu::MergeSubsequentDSROperations().run_on_function(nGraphFunc);
|
||||
|
||||
return InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, network);
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
@ -221,6 +221,7 @@ TEST(TransformationTests, ConvertMatMulTest7) {
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
auto pass_config = m.get_pass_config();
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||
@ -235,9 +236,8 @@ TEST(TransformationTests, ConvertMatMulTest7) {
|
||||
return false;
|
||||
};
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
m.set_callback(callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
pass_config->set_callback<ngraph::pass::ReshapeFullyConnected>(callback);
|
||||
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
@ -31,16 +31,9 @@ TEST(TransformationTests, ConvertPadToConv) {
|
||||
auto pad = std::make_shared<opset4::Pad>(input, pad_begin, pad_end, pad_value, pad_mode);
|
||||
f = std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
|
||||
|
||||
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
return std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node) != nullptr;
|
||||
};
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPadToGroupConvolution>();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
manager.set_callback(transformations_callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
@ -71,17 +64,10 @@ TEST(TransformationTests, ConvertPadToConvNeg1) {
|
||||
return std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
|
||||
};
|
||||
|
||||
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
return !!std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
|
||||
};
|
||||
|
||||
std::shared_ptr<Function> f(get_function()), f_ref(get_function());
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPadToGroupConvolution>();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
manager.set_callback(transformations_callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
@ -101,17 +87,10 @@ TEST(TransformationTests, ConvertPadToConvNeg2) {
|
||||
return std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
|
||||
};
|
||||
|
||||
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
return !!std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
|
||||
};
|
||||
|
||||
std::shared_ptr<Function> f(get_function()), f_ref(get_function());
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPadToGroupConvolution>();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
manager.set_callback(transformations_callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
@ -131,17 +110,10 @@ TEST(TransformationTests, ConvertPadToConvNeg3) {
|
||||
return std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
|
||||
};
|
||||
|
||||
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
return !!std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
|
||||
};
|
||||
|
||||
std::shared_ptr<Function> f(get_function()), f_ref(get_function());
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPadToGroupConvolution>();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
manager.set_callback(transformations_callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
@ -162,17 +134,10 @@ TEST(TransformationTests, ConvertPadToConvNeg4) {
|
||||
return std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
|
||||
};
|
||||
|
||||
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
return !!std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
|
||||
};
|
||||
|
||||
std::shared_ptr<Function> f(get_function()), f_ref(get_function());
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPadToGroupConvolution>();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
manager.set_callback(transformations_callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp>
|
||||
#include <legacy/ngraph_ops/power.hpp>
|
||||
#include <legacy/ngraph_ops/scaleshift.hpp>
|
||||
#include <legacy/ngraph_ops/eltwise.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
|
@ -21,8 +21,6 @@
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
|
||||
#include <common_test_utils/test_common.hpp>
|
||||
#include <gtest/gtest.h>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
|
||||
namespace {
|
||||
@ -46,8 +47,9 @@ protected:
|
||||
ngraph::NodeVector{shapeOf},
|
||||
ngraph::ParameterVector{data, shape},
|
||||
"Actual");
|
||||
|
||||
vpu::EliminateShapeOfAfterDSR().run_on_function(function);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<vpu::EliminateShapeOfAfterDSR>();
|
||||
manager.run_passes(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
@ -108,8 +110,9 @@ protected:
|
||||
ngraph::NodeVector{shapeOfOutputRelu},
|
||||
ngraph::ParameterVector{data, shape},
|
||||
"Actual");
|
||||
|
||||
vpu::EliminateShapeOfAfterDSR().run_on_function(function);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<vpu::EliminateShapeOfAfterDSR>();
|
||||
manager.run_passes(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
@ -174,7 +177,9 @@ protected:
|
||||
ngraph::ParameterVector{data, shape},
|
||||
"Actual");
|
||||
|
||||
vpu::EliminateShapeOfAfterDSR().run_on_function(function);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<vpu::EliminateShapeOfAfterDSR>();
|
||||
manager.run_passes(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
|
@ -34,7 +34,9 @@ TEST(MergeSubsequentDSROperations, smoke_SingleDSRFunction) {
|
||||
"SingleDSRFunction");
|
||||
auto actual = ngraph::clone_function(*reference);
|
||||
|
||||
vpu::MergeSubsequentDSROperations().run_on_function(actual);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<vpu::MergeSubsequentDSROperations>();
|
||||
manager.run_passes(actual);
|
||||
|
||||
ASSERT_NO_THROW(ngraph::helpers::CompareFunctions(*reference, *actual));
|
||||
}
|
||||
@ -80,7 +82,9 @@ TEST(MergeSubsequentDSROperations, smoke_DSR_ReLU_DSR_ReLU_DSR) {
|
||||
"DSR_ReLU_DSR_ReLU_DSR");
|
||||
auto actual = ngraph::clone_function(*reference);
|
||||
|
||||
vpu::MergeSubsequentDSROperations().run_on_function(actual);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<vpu::MergeSubsequentDSROperations>();
|
||||
manager.run_passes(actual);
|
||||
|
||||
ASSERT_NO_THROW(ngraph::helpers::CompareFunctions(*reference, *actual));
|
||||
}
|
||||
@ -161,7 +165,9 @@ TEST(MergeSubsequentDSROperations, smoke_DSR_ReLU_DSR_DSR) {
|
||||
"DSR_ReLU_DSR_DSR");
|
||||
}
|
||||
|
||||
vpu::MergeSubsequentDSROperations().run_on_function(actual);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<vpu::MergeSubsequentDSROperations>();
|
||||
manager.run_passes(actual);
|
||||
|
||||
ASSERT_NO_THROW(ngraph::helpers::CompareFunctions(*reference, *actual));
|
||||
}
|
||||
|
@ -26,19 +26,13 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
template<ngraph::element::Type_t from, ngraph::element::Type_t to>
|
||||
class ConvertPrecision : public ngraph::pass::GraphRewrite {
|
||||
class ConvertConstantsPrecision : public MatcherPass {
|
||||
public:
|
||||
ConvertPrecision() : GraphRewrite() {
|
||||
convert_constants_precision();
|
||||
convert_parameters_precision();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_constants_precision() {
|
||||
ConvertConstantsPrecision() {
|
||||
auto constant =
|
||||
std::make_shared<ngraph::op::Constant>(element::f32, Shape{1}, std::vector<float>{0});
|
||||
std::make_shared<ngraph::op::Constant>(element::f32, Shape{1}, std::vector<float>{0});
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto constant = std::dynamic_pointer_cast<ngraph::op::Constant>(m.get_match_root());
|
||||
if (!constant) {
|
||||
return false;
|
||||
@ -54,16 +48,18 @@ private:
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertPrecision");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertConstantsPrecision");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
void convert_parameters_precision() {
|
||||
template<ngraph::element::Type_t from, ngraph::element::Type_t to>
|
||||
class ConvertParametersPrecision : public MatcherPass {
|
||||
public:
|
||||
ConvertParametersPrecision() {
|
||||
auto constant = std::make_shared<ngraph::op::Parameter>(to, Shape{1});
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto parameter = std::dynamic_pointer_cast<ngraph::op::Parameter>(m.get_match_root());
|
||||
if (parameter && parameter->get_element_type() == ngraph::element::Type(from)) {
|
||||
parameter->set_element_type(to);
|
||||
@ -72,10 +68,17 @@ private:
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertPrecision");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertParametersPrecision");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
template<ngraph::element::Type_t from, ngraph::element::Type_t to>
|
||||
class ConvertPrecision : public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
ConvertPrecision() {
|
||||
add_matcher<ConvertConstantsPrecision<from, to>>();
|
||||
add_matcher<ConvertParametersPrecision<from, to>>();
|
||||
}
|
||||
};
|
||||
} // namespace pass
|
||||
|
Loading…
Reference in New Issue
Block a user