Remove deprecated classes/methods usage from Legacy/Tests/VPUPlugin (#3907)

* Fix legacy converter for Mul/Add/Sub ops

* Updated VPU plugin to use pass_config; updated tests to avoid legacy classes/methods

* Updated VPU pipeline
This commit is contained in:
Gleb Kazantaev 2021-01-21 12:06:07 +03:00 committed by GitHub
parent c6c1503ba1
commit 08eea31f45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 363 additions and 385 deletions

View File

@ -11,17 +11,6 @@
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
#include "legacy/ngraph_ops/scaleshift.hpp"
#include "legacy/ngraph_ops/eltwise.hpp"
#include "legacy/ngraph_ops/power.hpp"
#include "transformations/utils/utils.hpp"
#include "legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp"
namespace ngraph {
namespace pass {
@ -33,273 +22,5 @@ class INFERENCE_ENGINE_API_CLASS(ConvertMulOrAddFinally);
// Composite rewrite pass that lowers standalone Multiply/Add/Subtract nodes to
// legacy IE operations (ScaleShift, Power or Eltwise). One matcher per op type
// is registered via the private template helper below.
class ngraph::pass::ConvertMulOrAddFinally: public ngraph::pass::GraphRewrite {
public:
    NGRAPH_RTTI_DECLARATION;
    // This pass finally converts single Multiply and Add operations to ScaleShift or Power operation
    ConvertMulOrAddFinally();
private:
    // Registers a matcher for a single op type T (Add, Subtract or Multiply).
    template<typename T>
    void convert_mul_or_add_finally();
};
template <typename T>
bool convert_to_eltwise(std::shared_ptr<T> & node,
ngraph::Output<ngraph::Node> data1,
ngraph::Output<ngraph::Node> data2) {
ELTWISE_TYPE et;
if (std::is_same<T, ngraph::opset1::Multiply>()) {
et = ELTWISE_TYPE::Prod;
} else if (std::is_same<T, ngraph::opset1::Add>()) {
et = ELTWISE_TYPE::Sum;
} else if (std::is_same<T, ngraph::opset1::Subtract>()) {
et = ELTWISE_TYPE::Sub;
} else {
return false;
}
auto eltwise = std::make_shared<ngraph::op::Eltwise>(data1, data2, et, node->output(0).get_element_type());
eltwise->set_friendly_name(node->get_friendly_name());
ngraph::copy_runtime_info(node, eltwise);
ngraph::replace_node(node, eltwise);
return true;
}
// Builds the matcher callback shared by the Add/Subtract/Multiply matchers.
// It lowers a single opset1 arithmetic node into one of the legacy IE ops:
//   * Eltwise      - generic fallback (integer math or two non-constant inputs);
//   * ScaleShiftIE - constant operand is per-channel (CONVERSION_RESULT::SCALE_SHIFT)
//                    or the node is a marked dequantization operation;
//   * PowerIE      - constant operand folds to a single scalar value.
// Identity operations (x + 0, x * 1) that do not broadcast are removed instead.
// Returns true when the graph was modified.
template <typename T>
ngraph::matcher_pass_callback get_callback() {
    ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
        static_assert(std::is_same<T, ngraph::opset1::Add>() || std::is_same<T, ngraph::opset1::Subtract>() || std::is_same<T, ngraph::opset1::Multiply>(),
                      "Unsupported template parameter. Only Add, Subtract or Multiply allowed!");
        auto lin_op = std::dynamic_pointer_cast<T> (m.get_match_root());
        if (!lin_op || lin_op->output(0).get_partial_shape().rank().is_dynamic()) {
            return false;
        }

        const auto output_shape = lin_op->output(0).get_partial_shape();
        const auto output_shape_rank = output_shape.rank().get_length();

        // ScaleShift/Power require real types; pure integer math goes to Eltwise.
        const auto intInputs = !lin_op->get_input_element_type(0).is_real() &&
                               !lin_op->get_input_element_type(1).is_real();
        if (!lin_op->get_element_type().is_real() || intInputs) {
            return convert_to_eltwise<T>(lin_op,
                                         lin_op->input(0).get_source_output(),
                                         lin_op->input(1).get_source_output());
        }

        // Locate the constant operand (on either input); the other operand is the
        // data. When neither input is a Constant only Eltwise conversion applies.
        std::shared_ptr<ngraph::opset1::Constant> const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(
            lin_op->input(0).get_source_output().get_node_shared_ptr());
        auto data_node = lin_op->input(1).get_source_output();
        if (!const_node) {
            const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant> (lin_op->input(1).get_source_output().get_node_shared_ptr());
            data_node = lin_op->input(0).get_source_output();
            if (!const_node) {
                return convert_to_eltwise<T>(lin_op,
                                             lin_op->input(0).get_source_output(),
                                             lin_op->input(1).get_source_output());
            }
        }

        /* This lambda checks data and constant shapes for broadcasting
           For example:
                1. data_shape{1, 64, 64} and const_shape{64, 1, 1} - constant broadcasts data_shape zero dimension
                2. data_shape{DYN, 64, 64} and const_shape{1, 1, 64} - constant does not broadcast data_shape
                3. data_shape{64, 64} and const_shape{1, 1, 1} - constant broadcasts data_shape with additional dimension
        */
        auto constant_broadcast_output = [](const ngraph::PartialShape & data_pshape, const ngraph::Shape & const_shape) -> bool {
            if (data_pshape.rank().is_dynamic() || const_shape.size() > static_cast<size_t>(data_pshape.rank().get_length())) {
                return true;
            }
            // Align both shapes at the innermost dimension and compare pairwise.
            std::vector<ngraph::Dimension> data_shape(data_pshape);
            auto const_shape_it = const_shape.rbegin();
            auto data_shape_it = data_shape.rbegin();
            while (const_shape_it != const_shape.rend()) {
                auto data_dim = *data_shape_it;
                auto const_dim = *const_shape_it;
                /* DATA DIM - CONST DIM - CONSTANT BROADCAST OUTPUT
                      DYN   -    64     -      TRUE
                      DYN   -    1      -      FALSE
                      64    -    1      -      FALSE
                      1     -    64     -      TRUE
                      64    -    64     -      FALSE
                */
                if ((data_dim.is_dynamic() && const_dim != 1) ||
                    (data_dim.is_static() && data_dim.get_length() == 1 && const_dim != 1)) {
                    return true;
                }
                ++const_shape_it;
                ++data_shape_it;
            }
            return false;
        };

        // Check that the eltwise is not useless and does not broadcast the output;
        // if it is an identity op we simply remove it from the graph.
        if (((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
             (std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) &&
            !constant_broadcast_output(data_node.get_partial_shape(), const_node->get_shape())) {
            bool ret_status = ngraph::replace_output_update_name(lin_op->output(0), data_node);
            if (ret_status) {
                return true;
            }
        }

        auto res = check_constant(const_node, data_node.get_partial_shape());

        // Verifies the per-channel constant layout expected by a dequantization
        // ScaleShift: static input shape, constant rank <= 5, and the constant is
        // either scalar-like or spans only the channel dimension (all other dims == 1).
        auto checkElementwise = [](const std::shared_ptr<ngraph::Node>& elementwise) -> bool {
            const ngraph::PartialShape partialShape = elementwise->get_input_partial_shape(0);
            if (partialShape.is_dynamic()) {
                return false;
            }

            std::shared_ptr<ngraph::opset1::Constant> constant = ngraph::as_type_ptr<ngraph::opset1::Constant>(elementwise->get_input_node_shared_ptr(1));
            if (constant == nullptr) {
                constant = ngraph::as_type_ptr<ngraph::opset1::Constant>(elementwise->get_input_node_shared_ptr(0));
            }
            if (constant == nullptr) {
                return false;
            }

            const ngraph::Shape constShape = constant->get_output_shape(0);
            if ((constShape.size() > 5ul)) {
                return false;
            }
            // Scalar or all-ones constant is always acceptable.
            if ((constShape.size() <= 1ul) || (std::all_of(constShape.begin(), constShape.end(), [](const size_t value) { return value == 1ul; }))) {
                return true;
            }

            const ngraph::Shape shape = partialShape.to_shape();
            if (constShape.size() == shape.size()) {
                // Same rank: constant must be {1, C, 1, ...} matching the channel dim.
                if ((constShape[0] != 1ul) || (constShape[1] != shape[1])) {
                    return false;
                }
                for (size_t i = 2ul; i < constShape.size(); ++i) {
                    if (constShape[i] != 1ul) {
                        return false;
                    }
                }
            } else if (constShape.size() == (shape.size() - 1)) {
                // Rank reduced by one: constant must be {C, 1, ...}.
                if (constShape[0] != shape[1]) {
                    return false;
                }
                for (size_t i = 1ul; i < constShape.size(); ++i) {
                    if (constShape[i] != 1ul) {
                        return false;
                    }
                }
            } else {
                return false;
            }
            return true;
        };

        bool is_dequantization = (lin_op->get_rt_info().count("DEQUANTIZATION") != 0) && checkElementwise(lin_op);

        // Neither ScaleShift nor Power fits (or the rank is too small for
        // ScaleShift) - fall back to the generic Eltwise conversion.
        if (!is_dequantization && (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4))) {
            return convert_to_eltwise<T>(lin_op,
                                         lin_op->input(0).get_source_output(),
                                         lin_op->input(1).get_source_output());
        }

        // TODO: if all values in Constant are equal the best way is to convert this Eltwise to Power
        if (res == CONVERSION_RESULT::SCALE_SHIFT || is_dequantization) {
            auto weights_et = const_node->get_element_type();
            auto weights_shape = const_node->get_shape();

            // In case of Add we create fake weights with 1, in case of Multiply we create fake bias with 0
            std::shared_ptr<ngraph::op::ScaleShiftIE> scaleshift;
            if (std::is_same<T, ngraph::opset1::Add>()) {
                auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
                auto weights_in = ngraph::op::util::normalize_constant(weights, output_shape);
                auto biases_in = ngraph::op::util::normalize_constant(const_node, output_shape);
                if (is_dequantization) {
                    // Dequantization ScaleShift expects per-channel weights/biases.
                    const ngraph::Shape data_shape = data_node.get_shape();
                    ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
                    broadcasted_shape[1] = data_shape[1];
                    weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
                    biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
                }
                scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
            } else if (std::is_same<T, ngraph::opset1::Subtract>()) {
                // Subtract is modelled as ScaleShift with negated biases.
                std::shared_ptr<ngraph::Node> new_const_node = std::make_shared<ngraph::opset1::Multiply>(
                    ngraph::op::util::normalize_constant(const_node, output_shape),
                    ngraph::opset1::Constant::create(weights_et, ngraph::Shape{ 1 }, { -1 }));
                auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
                auto weights_in = ngraph::op::util::normalize_constant(weights, output_shape);
                auto biases_in = new_const_node;
                if (is_dequantization) {
                    const ngraph::Shape data_shape = data_node.get_shape();
                    ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
                    broadcasted_shape[1] = data_shape[1];
                    weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
                    biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
                }
                scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
            } else if (std::is_same<T, ngraph::opset1::Multiply>()) {
                auto bias = ngraph::opset1::Constant::create(weights_et, weights_shape, {0});
                auto weights_in = ngraph::op::util::normalize_constant(const_node, output_shape);
                auto biases_in = ngraph::op::util::normalize_constant(bias, output_shape);
                if (is_dequantization) {
                    const ngraph::Shape data_shape = data_node.get_shape();
                    ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
                    broadcasted_shape[1] = data_shape[1];
                    weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
                    biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
                }
                scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
            } else {
                return false;
            }

            scaleshift->set_friendly_name(lin_op->get_friendly_name());
            ngraph::copy_runtime_info(m.get_match_root(), scaleshift);
            ngraph::replace_node(m.get_match_root(), scaleshift);
        } else {
            // Scalar constant: lower to PowerIE (power = 1, scale/shift encode the op).
            float value;
            if (!ngraph::op::util::get_single_value(const_node, value)) {
                return false;
            }

            // In case Add we create fake scale equal to 1, in case of Multiply we create fake shift equal to 0
            std::shared_ptr<ngraph::op::PowerIE> power;
            if (std::is_same<T, ngraph::opset1::Add>()) {
                power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, 1.0f, value, lin_op->get_output_element_type(0));
            } else if (std::is_same<T, ngraph::opset1::Multiply>()) {
                power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, value, 0.0f, lin_op->get_output_element_type(0));
            } else if (std::is_same<T, ngraph::opset1::Subtract>()) {
                power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, 1.0f, -value, lin_op->get_output_element_type(0));
            } else {
                return false;
            }
            power->set_friendly_name(lin_op->get_friendly_name());
            ngraph::copy_runtime_info(m.get_match_root(), power);
            ngraph::replace_node(m.get_match_root(), power);
        }

        return true;
    };
    return callback;
}
// Registers one matcher for op type T using the deprecated GraphRewrite
// add_matcher overload (hence the suppression macros). The Label shapes/types
// are placeholders: strict matching is not requested, so any T node matches.
template <typename T>
void ngraph::pass::ConvertMulOrAddFinally::convert_mul_or_add_finally() {
    auto data_batch_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 2, 1, 1});
    auto data_batch_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 2, 1, 1});

    auto lin_op = std::make_shared<T>(data_batch_1, data_batch_2);

    auto m = std::make_shared<ngraph::pattern::Matcher>(lin_op);
    // CHANGE_DYNAMIC_STATE: the callback also handles partially dynamic shapes.
    NGRAPH_SUPPRESS_DEPRECATED_START
    this->add_matcher(m, get_callback<T>(), PassProperty::CHANGE_DYNAMIC_STATE);
    NGRAPH_SUPPRESS_DEPRECATED_END
}

View File

@ -3,11 +3,295 @@
//
#include "legacy/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp"
#include "legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp"
#include "legacy/ngraph_ops/scaleshift.hpp"
#include "legacy/ngraph_ops/eltwise.hpp"
#include "legacy/ngraph_ops/power.hpp"
#include "transformations/utils/utils.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMulOrAddFinally, "ConvertMulOrAddFinally", 0);
ngraph::pass::ConvertMulOrAddFinally::ConvertMulOrAddFinally() {
convert_mul_or_add_finally<ngraph::opset1::Add>();
convert_mul_or_add_finally<ngraph::opset1::Subtract>();
convert_mul_or_add_finally<ngraph::opset1::Multiply>();
template <typename T>
bool convert_to_eltwise(std::shared_ptr<T> & node,
ngraph::Output<ngraph::Node> data1,
ngraph::Output<ngraph::Node> data2) {
ELTWISE_TYPE et;
if (std::is_same<T, ngraph::opset1::Multiply>()) {
et = ELTWISE_TYPE::Prod;
} else if (std::is_same<T, ngraph::opset1::Add>()) {
et = ELTWISE_TYPE::Sum;
} else if (std::is_same<T, ngraph::opset1::Subtract>()) {
et = ELTWISE_TYPE::Sub;
} else {
return false;
}
auto eltwise = std::make_shared<ngraph::op::Eltwise>(data1, data2, et, node->output(0).get_element_type());
eltwise->set_friendly_name(node->get_friendly_name());
ngraph::copy_runtime_info(node, eltwise);
ngraph::replace_node(node, eltwise);
return true;
}
// Builds the matcher callback shared by the Add/Subtract/Multiply matchers.
// It lowers a single opset1 arithmetic node into one of the legacy IE ops:
//   * Eltwise      - generic fallback (integer math or two non-constant inputs);
//   * ScaleShiftIE - constant operand is per-channel (CONVERSION_RESULT::SCALE_SHIFT)
//                    or the node is a marked dequantization operation;
//   * PowerIE      - constant operand folds to a single scalar value.
// Identity operations (x + 0, x * 1) that do not broadcast are removed instead.
// Returns true when the graph was modified.
template <typename T>
ngraph::matcher_pass_callback get_callback() {
    ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
        static_assert(std::is_same<T, ngraph::opset1::Add>() || std::is_same<T, ngraph::opset1::Subtract>() || std::is_same<T, ngraph::opset1::Multiply>(),
                      "Unsupported template parameter. Only Add, Subtract or Multiply allowed!");
        auto lin_op = std::dynamic_pointer_cast<T> (m.get_match_root());
        if (!lin_op || lin_op->output(0).get_partial_shape().rank().is_dynamic()) {
            return false;
        }

        const auto output_shape = lin_op->output(0).get_partial_shape();
        const auto output_shape_rank = output_shape.rank().get_length();

        // ScaleShift/Power require real types; pure integer math goes to Eltwise.
        const auto intInputs = !lin_op->get_input_element_type(0).is_real() &&
                               !lin_op->get_input_element_type(1).is_real();
        if (!lin_op->get_element_type().is_real() || intInputs) {
            return convert_to_eltwise<T>(lin_op,
                                         lin_op->input(0).get_source_output(),
                                         lin_op->input(1).get_source_output());
        }

        // Locate the constant operand (on either input); the other operand is the
        // data. When neither input is a Constant only Eltwise conversion applies.
        std::shared_ptr<ngraph::opset1::Constant> const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(
            lin_op->input(0).get_source_output().get_node_shared_ptr());
        auto data_node = lin_op->input(1).get_source_output();
        if (!const_node) {
            const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant> (lin_op->input(1).get_source_output().get_node_shared_ptr());
            data_node = lin_op->input(0).get_source_output();
            if (!const_node) {
                return convert_to_eltwise<T>(lin_op,
                                             lin_op->input(0).get_source_output(),
                                             lin_op->input(1).get_source_output());
            }
        }

        /* This lambda checks data and constant shapes for broadcasting
           For example:
                1. data_shape{1, 64, 64} and const_shape{64, 1, 1} - constant broadcasts data_shape zero dimension
                2. data_shape{DYN, 64, 64} and const_shape{1, 1, 64} - constant does not broadcast data_shape
                3. data_shape{64, 64} and const_shape{1, 1, 1} - constant broadcasts data_shape with additional dimension
        */
        auto constant_broadcast_output = [](const ngraph::PartialShape & data_pshape, const ngraph::Shape & const_shape) -> bool {
            if (data_pshape.rank().is_dynamic() || const_shape.size() > static_cast<size_t>(data_pshape.rank().get_length())) {
                return true;
            }
            // Align both shapes at the innermost dimension and compare pairwise.
            std::vector<ngraph::Dimension> data_shape(data_pshape);
            auto const_shape_it = const_shape.rbegin();
            auto data_shape_it = data_shape.rbegin();
            while (const_shape_it != const_shape.rend()) {
                auto data_dim = *data_shape_it;
                auto const_dim = *const_shape_it;
                /* DATA DIM - CONST DIM - CONSTANT BROADCAST OUTPUT
                      DYN   -    64     -      TRUE
                      DYN   -    1      -      FALSE
                      64    -    1      -      FALSE
                      1     -    64     -      TRUE
                      64    -    64     -      FALSE
                */
                if ((data_dim.is_dynamic() && const_dim != 1) ||
                    (data_dim.is_static() && data_dim.get_length() == 1 && const_dim != 1)) {
                    return true;
                }
                ++const_shape_it;
                ++data_shape_it;
            }
            return false;
        };

        // Check that the eltwise is not useless and does not broadcast the output;
        // if it is an identity op we simply remove it from the graph.
        if (((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
             (std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) &&
            !constant_broadcast_output(data_node.get_partial_shape(), const_node->get_shape())) {
            bool ret_status = ngraph::replace_output_update_name(lin_op->output(0), data_node);
            if (ret_status) {
                return true;
            }
        }

        auto res = check_constant(const_node, data_node.get_partial_shape());

        // Verifies the per-channel constant layout expected by a dequantization
        // ScaleShift: static input shape, constant rank <= 5, and the constant is
        // either scalar-like or spans only the channel dimension (all other dims == 1).
        auto checkElementwise = [](const std::shared_ptr<ngraph::Node>& elementwise) -> bool {
            const ngraph::PartialShape partialShape = elementwise->get_input_partial_shape(0);
            if (partialShape.is_dynamic()) {
                return false;
            }

            std::shared_ptr<ngraph::opset1::Constant> constant = ngraph::as_type_ptr<ngraph::opset1::Constant>(elementwise->get_input_node_shared_ptr(1));
            if (constant == nullptr) {
                constant = ngraph::as_type_ptr<ngraph::opset1::Constant>(elementwise->get_input_node_shared_ptr(0));
            }
            if (constant == nullptr) {
                return false;
            }

            const ngraph::Shape constShape = constant->get_output_shape(0);
            if ((constShape.size() > 5ul)) {
                return false;
            }
            // Scalar or all-ones constant is always acceptable.
            if ((constShape.size() <= 1ul) || (std::all_of(constShape.begin(), constShape.end(), [](const size_t value) { return value == 1ul; }))) {
                return true;
            }

            const ngraph::Shape shape = partialShape.to_shape();
            if (constShape.size() == shape.size()) {
                // Same rank: constant must be {1, C, 1, ...} matching the channel dim.
                if ((constShape[0] != 1ul) || (constShape[1] != shape[1])) {
                    return false;
                }
                for (size_t i = 2ul; i < constShape.size(); ++i) {
                    if (constShape[i] != 1ul) {
                        return false;
                    }
                }
            } else if (constShape.size() == (shape.size() - 1)) {
                // Rank reduced by one: constant must be {C, 1, ...}.
                if (constShape[0] != shape[1]) {
                    return false;
                }
                for (size_t i = 1ul; i < constShape.size(); ++i) {
                    if (constShape[i] != 1ul) {
                        return false;
                    }
                }
            } else {
                return false;
            }
            return true;
        };

        bool is_dequantization = (lin_op->get_rt_info().count("DEQUANTIZATION") != 0) && checkElementwise(lin_op);

        // Neither ScaleShift nor Power fits (or the rank is too small for
        // ScaleShift) - fall back to the generic Eltwise conversion.
        if (!is_dequantization && (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4))) {
            return convert_to_eltwise<T>(lin_op,
                                         lin_op->input(0).get_source_output(),
                                         lin_op->input(1).get_source_output());
        }

        // TODO: if all values in Constant are equal the best way is to convert this Eltwise to Power
        if (res == CONVERSION_RESULT::SCALE_SHIFT || is_dequantization) {
            auto weights_et = const_node->get_element_type();
            auto weights_shape = const_node->get_shape();

            // In case of Add we create fake weights with 1, in case of Multiply we create fake bias with 0
            std::shared_ptr<ngraph::op::ScaleShiftIE> scaleshift;
            if (std::is_same<T, ngraph::opset1::Add>()) {
                auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
                auto weights_in = ngraph::op::util::normalize_constant(weights, output_shape);
                auto biases_in = ngraph::op::util::normalize_constant(const_node, output_shape);
                if (is_dequantization) {
                    // Dequantization ScaleShift expects per-channel weights/biases.
                    const ngraph::Shape data_shape = data_node.get_shape();
                    ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
                    broadcasted_shape[1] = data_shape[1];
                    weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
                    biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
                }
                scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
            } else if (std::is_same<T, ngraph::opset1::Subtract>()) {
                // Subtract is modelled as ScaleShift with negated biases.
                std::shared_ptr<ngraph::Node> new_const_node = std::make_shared<ngraph::opset1::Multiply>(
                    ngraph::op::util::normalize_constant(const_node, output_shape),
                    ngraph::opset1::Constant::create(weights_et, ngraph::Shape{ 1 }, { -1 }));
                auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
                auto weights_in = ngraph::op::util::normalize_constant(weights, output_shape);
                auto biases_in = new_const_node;
                if (is_dequantization) {
                    const ngraph::Shape data_shape = data_node.get_shape();
                    ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
                    broadcasted_shape[1] = data_shape[1];
                    weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
                    biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
                }
                scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
            } else if (std::is_same<T, ngraph::opset1::Multiply>()) {
                auto bias = ngraph::opset1::Constant::create(weights_et, weights_shape, {0});
                auto weights_in = ngraph::op::util::normalize_constant(const_node, output_shape);
                auto biases_in = ngraph::op::util::normalize_constant(bias, output_shape);
                if (is_dequantization) {
                    const ngraph::Shape data_shape = data_node.get_shape();
                    ngraph::Shape broadcasted_shape = std::vector<size_t>(data_shape.size(), 1ul);
                    broadcasted_shape[1] = data_shape[1];
                    weights_in = ngraph::op::util::broadcastTo(weights_in, broadcasted_shape);
                    biases_in = ngraph::op::util::broadcastTo(biases_in, broadcasted_shape);
                }
                scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
            } else {
                return false;
            }

            scaleshift->set_friendly_name(lin_op->get_friendly_name());
            ngraph::copy_runtime_info(m.get_match_root(), scaleshift);
            ngraph::replace_node(m.get_match_root(), scaleshift);
        } else {
            // Scalar constant: lower to PowerIE (power = 1, scale/shift encode the op).
            float value;
            if (!ngraph::op::util::get_single_value(const_node, value)) {
                return false;
            }

            // In case Add we create fake scale equal to 1, in case of Multiply we create fake shift equal to 0
            std::shared_ptr<ngraph::op::PowerIE> power;
            if (std::is_same<T, ngraph::opset1::Add>()) {
                power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, 1.0f, value, lin_op->get_output_element_type(0));
            } else if (std::is_same<T, ngraph::opset1::Multiply>()) {
                power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, value, 0.0f, lin_op->get_output_element_type(0));
            } else if (std::is_same<T, ngraph::opset1::Subtract>()) {
                power = std::make_shared<ngraph::op::PowerIE>(data_node, 1.0f, 1.0f, -value, lin_op->get_output_element_type(0));
            } else {
                return false;
            }
            power->set_friendly_name(lin_op->get_friendly_name());
            ngraph::copy_runtime_info(m.get_match_root(), power);
            ngraph::replace_node(m.get_match_root(), power);
        }

        return true;
    };
    return callback;
}
// Matcher pass lowering a standalone opset1::Add via the shared callback.
class ConvertAdd: public ngraph::pass::MatcherPass {
public:
    ConvertAdd() {
        const auto add_pattern = ngraph::pattern::wrap_type<ngraph::opset1::Add>();
        auto matcher = std::make_shared<ngraph::pattern::Matcher>(add_pattern);
        register_matcher(matcher, get_callback<ngraph::opset1::Add>());
    }
};
// Matcher pass lowering a standalone opset1::Subtract via the shared callback.
class ConvertSub: public ngraph::pass::MatcherPass {
public:
    ConvertSub() {
        const auto sub_pattern = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>();
        auto matcher = std::make_shared<ngraph::pattern::Matcher>(sub_pattern);
        register_matcher(matcher, get_callback<ngraph::opset1::Subtract>());
    }
};
// Matcher pass lowering a standalone opset1::Multiply via the shared callback.
class ConvertMul: public ngraph::pass::MatcherPass {
public:
    ConvertMul() {
        const auto mul_pattern = ngraph::pattern::wrap_type<ngraph::opset1::Multiply>();
        auto matcher = std::make_shared<ngraph::pattern::Matcher>(mul_pattern);
        register_matcher(matcher, get_callback<ngraph::opset1::Multiply>());
    }
};
// Registers the per-op matcher passes. Registration order defines the order
// in which the matchers are tried on each node.
ngraph::pass::ConvertMulOrAddFinally::ConvertMulOrAddFinally() {
    add_matcher<ConvertMul>();
    add_matcher<ConvertAdd>();
    add_matcher<ConvertSub>();
}

View File

@ -8,7 +8,7 @@
namespace vpu {
class EliminateShapeOfAfterDSR : public ngraph::pass::GraphRewrite {
class EliminateShapeOfAfterDSR : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
EliminateShapeOfAfterDSR();

View File

@ -8,7 +8,7 @@
namespace vpu {
class MergeSubsequentDSROperations : public ngraph::pass::GraphRewrite {
class MergeSubsequentDSROperations : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
MergeSubsequentDSROperations();

View File

@ -13,14 +13,14 @@ NGRAPH_RTTI_DEFINITION(vpu::EliminateShapeOfAfterDSR, "EliminateShapeOfAfterDSR"
namespace vpu {
EliminateShapeOfAfterDSR::EliminateShapeOfAfterDSR() : GraphRewrite() {
EliminateShapeOfAfterDSR::EliminateShapeOfAfterDSR() {
// We don't set strict_mode when use pattern Matcher,
// so we can set any type and shape for input.
auto inputWithAnyTypeAndShape = std::make_shared<ngraph::pattern::op::Label>(
ngraph::element::dynamic, ngraph::PartialShape{});
auto shapeOfPattern = std::make_shared<ngraph::opset3::ShapeOf>(inputWithAnyTypeAndShape);
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher &m) {
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
auto shapeOfNode = std::dynamic_pointer_cast<ngraph::opset3::ShapeOf>(m.get_match_root());
if (!shapeOfNode) {
return false;
@ -36,9 +36,7 @@ EliminateShapeOfAfterDSR::EliminateShapeOfAfterDSR() : GraphRewrite() {
};
auto m = std::make_shared<ngraph::pattern::Matcher>(shapeOfPattern, "EliminateShapeOfAfterDSR");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(m, callback);
}
} // namespace vpu

View File

@ -9,8 +9,8 @@ NGRAPH_RTTI_DEFINITION(vpu::MergeSubsequentDSROperations, "MergeSubsequentDSROpe
namespace vpu {
MergeSubsequentDSROperations::MergeSubsequentDSROperations() : ngraph::pass::GraphRewrite() {
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher& m) {
MergeSubsequentDSROperations::MergeSubsequentDSROperations() {
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
const auto& dsr = std::dynamic_pointer_cast<ngraph::vpu::op::DynamicShapeResolver>(m.get_match_root());
if (!dsr) {
return false;
@ -31,9 +31,7 @@ MergeSubsequentDSROperations::MergeSubsequentDSROperations() : ngraph::pass::Gra
ngraph::pattern::has_class<ngraph::vpu::op::DynamicShapeResolver>());
const auto& matcher = std::make_shared<ngraph::pattern::Matcher>(label, "MergeSubsequentDSROperations");
NGRAPH_SUPPRESS_DEPRECATED_START
add_matcher(matcher, callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(matcher, callback);
}
} // namespace vpu

View File

@ -28,6 +28,10 @@
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
#include <transformations/op_conversions/convert_gelu.hpp>
#include <transformations/op_conversions/softplus_decomposition.hpp>
#include <transformations/op_conversions/convert_minimum_to_power_and_max.hpp>
#include <transformations/op_conversions/hswish_decomposition.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
@ -39,6 +43,8 @@
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
#include <legacy/ie_util_internal.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_gather_to_gather_ie.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_matmul_to_fc_or_gemm.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.hpp>
#include <vpu/ngraph/transformations/extract_dynamic_batch/extract_dynamic_batch.hpp>
namespace vpu {
@ -154,22 +160,6 @@ ModelPtr FrontEnd::buildInitialModel(const ie::ICNNNetwork& network) {
}
ie::ICNNNetwork::Ptr FrontEnd::convertNetwork(ie::ICNNNetwork& network) {
// disable transformations for some cases
const auto transformationsPredicate = [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
const bool casesWithDynamicOrStaticUsage =
std::dynamic_pointer_cast<const ngraph::opset3::Gelu>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::SoftPlus>(node) ||
std::dynamic_pointer_cast<const ngraph::opset5::Minimum>(node) ||
std::dynamic_pointer_cast<const ngraph::opset5::HSwish>(node);
const bool casesWithOnlyDynamicUsage =
(std::dynamic_pointer_cast<const ngraph::opset3::MatMul>(node) ||
std::dynamic_pointer_cast<const ngraph::opset3::StridedSlice>(node)) &&
std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr());
return casesWithDynamicOrStaticUsage || casesWithOnlyDynamicUsage;
};
auto nGraphFunc = network.getFunction();
// Disable shape inference (WA for generic operations)
ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
@ -195,15 +185,23 @@ ie::ICNNNetwork::Ptr FrontEnd::convertNetwork(ie::ICNNNetwork& network) {
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
manager.get_pass_config()->disable<ngraph::pass::ConvertGatherToGatherIEMatcher>();
manager.register_pass<vpu::MergeSubsequentDSROperations>();
auto pass_config = manager.get_pass_config();
pass_config->disable<ngraph::pass::ConvertGatherToGatherIEMatcher>();
pass_config->disable<ngraph::pass::ConvertGELU>();
pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
pass_config->disable<ngraph::pass::ConvertMinimum>();
pass_config->disable<ngraph::pass::HSwishDecomposition>();
auto transformationPredicate = [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
return !!std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr());
};
pass_config->set_callback<ngraph::pass::ConvertMatMulToFC,
ngraph::pass::ConvertStridedSliceToCropMatcher>(transformationPredicate);
NGRAPH_SUPPRESS_DEPRECATED_START
manager.set_callback(transformationsPredicate);
NGRAPH_SUPPRESS_DEPRECATED_END
manager.run_passes(nGraphFunc);
vpu::MergeSubsequentDSROperations().run_on_function(nGraphFunc);
return InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, network);
}

View File

@ -13,6 +13,7 @@
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"

View File

@ -221,6 +221,7 @@ TEST(TransformationTests, ConvertMatMulTest7) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
ngraph::pass::Manager m;
auto pass_config = m.get_pass_config();
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
@ -235,9 +236,8 @@ TEST(TransformationTests, ConvertMatMulTest7) {
return false;
};
NGRAPH_SUPPRESS_DEPRECATED_START
m.set_callback(callback);
NGRAPH_SUPPRESS_DEPRECATED_END
pass_config->set_callback<ngraph::pass::ReshapeFullyConnected>(callback);
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

View File

@ -31,16 +31,9 @@ TEST(TransformationTests, ConvertPadToConv) {
auto pad = std::make_shared<opset4::Pad>(input, pad_begin, pad_end, pad_value, pad_mode);
f = std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node) != nullptr;
};
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPadToGroupConvolution>();
NGRAPH_SUPPRESS_DEPRECATED_START
manager.set_callback(transformations_callback);
NGRAPH_SUPPRESS_DEPRECATED_END
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
@ -71,17 +64,10 @@ TEST(TransformationTests, ConvertPadToConvNeg1) {
return std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
};
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
return !!std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
};
std::shared_ptr<Function> f(get_function()), f_ref(get_function());
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPadToGroupConvolution>();
NGRAPH_SUPPRESS_DEPRECATED_START
manager.set_callback(transformations_callback);
NGRAPH_SUPPRESS_DEPRECATED_END
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
@ -101,17 +87,10 @@ TEST(TransformationTests, ConvertPadToConvNeg2) {
return std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
};
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
return !!std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
};
std::shared_ptr<Function> f(get_function()), f_ref(get_function());
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPadToGroupConvolution>();
NGRAPH_SUPPRESS_DEPRECATED_START
manager.set_callback(transformations_callback);
NGRAPH_SUPPRESS_DEPRECATED_END
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
@ -131,17 +110,10 @@ TEST(TransformationTests, ConvertPadToConvNeg3) {
return std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
};
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
return !!std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
};
std::shared_ptr<Function> f(get_function()), f_ref(get_function());
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPadToGroupConvolution>();
NGRAPH_SUPPRESS_DEPRECATED_START
manager.set_callback(transformations_callback);
NGRAPH_SUPPRESS_DEPRECATED_END
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
@ -162,17 +134,10 @@ TEST(TransformationTests, ConvertPadToConvNeg4) {
return std::make_shared<Function>(NodeVector{pad}, ParameterVector{input});
};
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
return !!std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
};
std::shared_ptr<Function> f(get_function()), f_ref(get_function());
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPadToGroupConvolution>();
NGRAPH_SUPPRESS_DEPRECATED_START
manager.set_callback(transformations_callback);
NGRAPH_SUPPRESS_DEPRECATED_END
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));

View File

@ -23,6 +23,7 @@
#include <legacy/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp>
#include <legacy/ngraph_ops/power.hpp>
#include <legacy/ngraph_ops/scaleshift.hpp>
#include <legacy/ngraph_ops/eltwise.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"

View File

@ -21,8 +21,6 @@
#include "common_test_utils/ngraph_test_utils.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace ngraph;
using namespace std;

View File

@ -13,6 +13,7 @@
#include <common_test_utils/test_common.hpp>
#include <gtest/gtest.h>
#include <ngraph/pass/manager.hpp>
namespace {
@ -46,8 +47,9 @@ protected:
ngraph::NodeVector{shapeOf},
ngraph::ParameterVector{data, shape},
"Actual");
vpu::EliminateShapeOfAfterDSR().run_on_function(function);
ngraph::pass::Manager manager;
manager.register_pass<vpu::EliminateShapeOfAfterDSR>();
manager.run_passes(function);
return function;
}
@ -108,8 +110,9 @@ protected:
ngraph::NodeVector{shapeOfOutputRelu},
ngraph::ParameterVector{data, shape},
"Actual");
vpu::EliminateShapeOfAfterDSR().run_on_function(function);
ngraph::pass::Manager manager;
manager.register_pass<vpu::EliminateShapeOfAfterDSR>();
manager.run_passes(function);
return function;
}
@ -174,7 +177,9 @@ protected:
ngraph::ParameterVector{data, shape},
"Actual");
vpu::EliminateShapeOfAfterDSR().run_on_function(function);
ngraph::pass::Manager manager;
manager.register_pass<vpu::EliminateShapeOfAfterDSR>();
manager.run_passes(function);
return function;
}

View File

@ -34,7 +34,9 @@ TEST(MergeSubsequentDSROperations, smoke_SingleDSRFunction) {
"SingleDSRFunction");
auto actual = ngraph::clone_function(*reference);
vpu::MergeSubsequentDSROperations().run_on_function(actual);
ngraph::pass::Manager manager;
manager.register_pass<vpu::MergeSubsequentDSROperations>();
manager.run_passes(actual);
ASSERT_NO_THROW(ngraph::helpers::CompareFunctions(*reference, *actual));
}
@ -80,7 +82,9 @@ TEST(MergeSubsequentDSROperations, smoke_DSR_ReLU_DSR_ReLU_DSR) {
"DSR_ReLU_DSR_ReLU_DSR");
auto actual = ngraph::clone_function(*reference);
vpu::MergeSubsequentDSROperations().run_on_function(actual);
ngraph::pass::Manager manager;
manager.register_pass<vpu::MergeSubsequentDSROperations>();
manager.run_passes(actual);
ASSERT_NO_THROW(ngraph::helpers::CompareFunctions(*reference, *actual));
}
@ -161,7 +165,9 @@ TEST(MergeSubsequentDSROperations, smoke_DSR_ReLU_DSR_DSR) {
"DSR_ReLU_DSR_DSR");
}
vpu::MergeSubsequentDSROperations().run_on_function(actual);
ngraph::pass::Manager manager;
manager.register_pass<vpu::MergeSubsequentDSROperations>();
manager.run_passes(actual);
ASSERT_NO_THROW(ngraph::helpers::CompareFunctions(*reference, *actual));
}

View File

@ -26,19 +26,13 @@ namespace ngraph {
namespace pass {
template<ngraph::element::Type_t from, ngraph::element::Type_t to>
class ConvertPrecision : public ngraph::pass::GraphRewrite {
class ConvertConstantsPrecision : public MatcherPass {
public:
ConvertPrecision() : GraphRewrite() {
convert_constants_precision();
convert_parameters_precision();
}
private:
void convert_constants_precision() {
ConvertConstantsPrecision() {
auto constant =
std::make_shared<ngraph::op::Constant>(element::f32, Shape{1}, std::vector<float>{0});
std::make_shared<ngraph::op::Constant>(element::f32, Shape{1}, std::vector<float>{0});
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
auto constant = std::dynamic_pointer_cast<ngraph::op::Constant>(m.get_match_root());
if (!constant) {
return false;
@ -54,16 +48,18 @@ private:
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertPrecision");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertConstantsPrecision");
register_matcher(m, callback);
}
};
void convert_parameters_precision() {
template<ngraph::element::Type_t from, ngraph::element::Type_t to>
class ConvertParametersPrecision : public MatcherPass {
public:
ConvertParametersPrecision() {
auto constant = std::make_shared<ngraph::op::Parameter>(to, Shape{1});
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
auto parameter = std::dynamic_pointer_cast<ngraph::op::Parameter>(m.get_match_root());
if (parameter && parameter->get_element_type() == ngraph::element::Type(from)) {
parameter->set_element_type(to);
@ -72,10 +68,17 @@ private:
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertPrecision");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertParametersPrecision");
register_matcher(m, callback);
}
};
template<ngraph::element::Type_t from, ngraph::element::Type_t to>
class ConvertPrecision : public ngraph::pass::GraphRewrite {
public:
ConvertPrecision() {
add_matcher<ConvertConstantsPrecision<from, to>>();
add_matcher<ConvertParametersPrecision<from, to>>();
}
};
} // namespace pass