Enable common optimizations on MO (#6245)

* Enable common optimizations on MO

* Added tensor name tracking; updated tests

* Disable DilatedConvolution transform

* Fix TopK3 transformation

* Codestyle fix

* Update tensor name logic

* Fix scatter nd shape inference for dynamic shape

* Update FrameworkNode to propagate dynamic output shape

* Enable HSwish in MO that is missing in nGraph

* Cleanup MO and IE code

* Fix review comments

* Fix unit test
Gleb Kazantaev 2021-07-09 00:38:07 +03:00 committed by GitHub
parent 8b52a4c0c5
commit 2a970a56d3
27 changed files with 352 additions and 777 deletions

View File

@@ -5,12 +5,54 @@
#include <memory>
#include "moc_transformations.hpp"
#include "pruning.hpp"
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/common_optimizations/gelu_fusion.hpp>
#include <transformations/common_optimizations/softplus_fusion.hpp>
#include <transformations/common_optimizations/softplus_to_mish_fusion.hpp>
#include <transformations/common_optimizations/swish_fusion.hpp>
#include <transformations/common_optimizations/remove_filtering_boxes_by_size.hpp>
#include <transformations/common_optimizations/hsigmoid_fusion.hpp>
#include <transformations/common_optimizations/hswish_fusion.hpp>
#include <transformations/common_optimizations/convert_quantize_dequantize.hpp>
#include <transformations/common_optimizations/pad_fusion.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::Function> f) {
// To avoid issues with dynamism, we temporarily make the nGraph Function dynamic; after all
// transformations are applied, the original shapes are restored
std::unordered_map<ngraph::op::Parameter*, PartialShape> input_shapes;
for (auto && param : f->get_parameters()) {
input_shapes[param.get()] = param->get_partial_shape();
param->set_partial_shape(PartialShape::dynamic(param->get_partial_shape().rank()));
}
f->validate_nodes_and_infer_types();
ngraph::pass::Manager manager(get_pass_config());
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusToMishFusion>();
common_fusions->add_matcher<ngraph::pass::SwishFusion>();
common_fusions->add_matcher<ngraph::pass::HSwishFusion>();
common_fusions->add_matcher<ngraph::pass::HSigmoidFusion>();
common_fusions->add_matcher<ngraph::pass::PadFusion>();
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");
manager.run_passes(f);
// Restore original shapes to the nGraph Function
for (auto && param : f->get_parameters()) {
param->set_partial_shape(input_shapes.at(param.get()));
}
f->validate_nodes_and_infer_types();
return false;
}
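For context, a minimal sketch (not part of the diff) of how this pass could be invoked on an nGraph Function, assuming MOCTransformations is default-constructible as registered here:
void run_moc_offline(std::shared_ptr<ngraph::Function> f) {
    ngraph::pass::Manager manager;
    // MOCTransformations relaxes Parameter shapes to dynamic, runs the common
    // fusions registered above, and then restores the original shapes.
    manager.register_pass<ngraph::pass::MOCTransformations>();
    manager.run_passes(f);
}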

View File

@@ -71,8 +71,11 @@ public:
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
void cache_output_descriptor();
private:
std::vector<std::tuple<ngraph::PartialShape, ngraph::element::Type>> m_inputs_desc;
std::vector<std::tuple<ngraph::PartialShape, ngraph::element::Type>> m_output_desc;
FrameworkNodeAttrs m_attrs;
};

View File

@@ -12,6 +12,7 @@ namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API PadFusion;
class TRANSFORMATIONS_API PadElimination;
class TRANSFORMATIONS_API PadFusionAvgPool;
class TRANSFORMATIONS_API PadFusionMaxPool;
class TRANSFORMATIONS_API PadFusionConvolution;
@@ -22,6 +23,16 @@ class TRANSFORMATIONS_API PadFusionGroupConvolutionBackpropData;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief PadElimination eliminates a Pad operation whose padding values are all zero, i.e. a Pad that does nothing
*/
class ngraph::pass::PadElimination: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
PadElimination();
};
/**
* @ingroup ie_transformation_common_api
* @brief PadFusion transformation replaces following graph:
@@ -113,5 +124,6 @@ public:
add_matcher<ngraph::pass::PadFusionConvolutionBackpropData>();
add_matcher<ngraph::pass::PadFusionGroupConvolution>();
add_matcher<ngraph::pass::PadFusionGroupConvolutionBackpropData>();
add_matcher<ngraph::pass::PadElimination>();
}
};

View File

@@ -25,31 +25,71 @@ shared_ptr<Node> op::FrameworkNode::clone_with_new_inputs(const OutputVector& ne
return node;
}
void op::FrameworkNode::cache_output_descriptor() {
for (size_t i = 0; i < get_output_size(); ++i) {
m_output_desc.emplace_back(get_output_partial_shape(i), get_output_element_type(i));
}
}
void op::FrameworkNode::validate_and_infer_types() {
INTERNAL_OP_SCOPE(FrameworkNode_validate_and_infer_types);
// Save initial input descriptors
bool initialize_input_desc = m_inputs_desc.empty();
bool reset_output_shape_to_dynamic = false;
bool reset_output_shape_to_original = false;
for (uint64_t i = 0; i < get_input_size(); i++) {
// TODO: store constant values
const auto& input_pshape = get_input_partial_shape(i);
const auto& input_type = get_input_element_type(i);
const auto& rank = input_pshape.rank();
const auto & get_error_message = [&]() {
std::stringstream out;
out << "Input descriptor for " << get_friendly_name()
<< " node has been changed:" << std::endl;
out << "Before: " << std::get<0>(m_inputs_desc[i]) << ", "
<< std::get<1>(m_inputs_desc[i]) << std::endl;
out << "After: " << input_pshape << ", "
<< input_type << std::endl;
out << "Please specify InferenceEngine Extensions to support this case.";
return out.str();
};
if (initialize_input_desc) {
m_inputs_desc.emplace_back(input_pshape, input_type);
} else {
const auto& orig_input_pshape = std::get<0>(m_inputs_desc[i]);
if (orig_input_pshape == input_pshape) {
reset_output_shape_to_original = true;
} else if (input_pshape.rank().is_dynamic()) {
reset_output_shape_to_dynamic = true;
} else if (rank.is_static() && orig_input_pshape.rank().is_static() &&
rank.get_length() == orig_input_pshape.rank().get_length()) {
for (int64_t dim = 0; dim < rank.get_length(); ++dim) {
NODE_VALIDATION_CHECK(this, input_pshape[dim].is_dynamic() ||
(orig_input_pshape[dim].is_static() &&
orig_input_pshape[dim].get_length() == input_pshape[dim].get_length()),
get_error_message());
}
reset_output_shape_to_dynamic = true;
} else {
NODE_VALIDATION_CHECK(this, m_inputs_desc[i] == std::make_tuple(input_pshape, input_type), get_error_message());
}
}
}
if (reset_output_shape_to_dynamic) {
cache_output_descriptor();
for (size_t i = 0; i < get_output_size(); ++i) {
if (get_output_partial_shape(i).rank().is_static()) {
set_output_type(i, get_output_element_type(i), PartialShape::dynamic());
}
}
}
if (reset_output_shape_to_original && !m_output_desc.empty()) {
for (size_t i = 0; i < get_output_size(); ++i) {
set_output_type(i, std::get<1>(m_output_desc[i]), std::get<0>(m_output_desc[i]));
}
}
}

View File

@@ -12,6 +12,7 @@
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/validation_util.hpp>
using namespace ngraph;
@@ -385,3 +386,34 @@ pass::PadFusionGroupConvolutionBackpropData::PadFusionGroupConvolutionBackpropDa
auto m = std::make_shared<pattern::Matcher>(conv_pattern, matcher_name);
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(pass::PadElimination, "PadElimination", 0);
pass::PadElimination::PadElimination() {
MATCHER_SCOPE(PadElimination);
auto pad_node_pattern = pattern::wrap_type<opset5::Pad>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto pad = m.get_match_root();
auto pad_begin_const = ngraph::get_constant_from_source(pad->input_value(1));
auto pad_end_const = ngraph::get_constant_from_source(pad->input_value(2));
if (!pad_begin_const || !pad_end_const) {
return false;
}
const auto pad_begin_value = pad_begin_const->cast_vector<int64_t>();
const auto pad_end_value = pad_end_const->cast_vector<int64_t>();
if (std::any_of(pad_begin_value.begin(), pad_begin_value.end(), [](int64_t value) { return value != 0; }) ||
std::any_of(pad_end_value.begin(), pad_end_value.end(), [](int64_t value) { return value != 0; })) {
return false;
}
return replace_output_update_name(pad->output(0), pad->input_value(0));
};
auto m = std::make_shared<pattern::Matcher>(pad_node_pattern, matcher_name);
this->register_matcher(m, callback);
}
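A minimal usage sketch (not part of the diff, assuming the same opset5 builders and using-directives as the tests below): a Pad whose pads_begin and pads_end are all zeros is bypassed, its consumer is reconnected to the Pad input, and replace_output_update_name merges the tensor names.
auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 14, 14});
auto begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 0, 0});
auto end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 0, 0});
auto pad = std::make_shared<opset5::Pad>(data, begin, end, op::PadMode::CONSTANT);
auto relu = std::make_shared<opset5::Relu>(pad);
auto f = std::make_shared<Function>(NodeVector{relu}, ParameterVector{data});
pass::Manager m;
m.register_pass<pass::PadFusion>(); // PadElimination runs as part of this GraphRewrite
m.run_passes(f);
// Relu now reads directly from the Parameter; the zero Pad has been eliminated.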

View File

@@ -19,26 +19,20 @@ ngraph::pass::SoftPlusFusion::SoftPlusFusion() {
// fuses ln(exp(x) + 1.0) operations into SoftPlus(x)
auto input = ngraph::pattern::any_input();
auto exp = std::make_shared<ngraph::opset4::Exp>(input);
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>(
pattern::type_matches_any({element::f32, element::f16}));
auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
auto log = std::make_shared<ngraph::opset4::Log>(add);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto &pattern_to_output = m.get_pattern_value_map();
auto exp_input = pattern_to_output.at(input);
auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
if (constant == nullptr) {
return false;
}
auto data = constant->cast_vector<float>();
if (data.size() != 1 || data[0] != 1.0) {
return false;
}

View File

@@ -40,6 +40,7 @@ ngraph::pass::ConvertTopK3::ConvertTopK3() {
last1 = new_topk->output(1);
new_topk->set_friendly_name(topk->get_friendly_name());
} else if (topk->get_output_target_inputs(0).size() == 0) {
last0 = topk->output(0);
last1 = std::make_shared<ngraph::opset2::Convert>(new_topk->output(1), topk->get_index_element_type());
new_ops.push_back(last1.get_node_shared_ptr());

View File

@@ -0,0 +1,59 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph_ops/framework_node.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
TEST(TransformationTests, FrameworkNode) {
auto param = std::make_shared<ngraph::opset8::Parameter>(element::i64, Shape{1, 64});
auto f_node = std::make_shared<ngraph::op::FrameworkNode>(OutputVector{param});
f_node->set_output_type(0, element::i64, Shape{1, 64});
// Set partially dynamic shape
param->set_partial_shape(PartialShape{Dimension::dynamic(), 64});
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), PartialShape::dynamic());
// Set dynamic shape
param->set_partial_shape(PartialShape::dynamic(2));
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), PartialShape::dynamic());
// Set fully dynamic shape
param->set_partial_shape(PartialShape::dynamic());
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), PartialShape::dynamic());
// Set original static shape
param->set_partial_shape(Shape{1, 64});
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), PartialShape({1, 64}));
// Set different static shape
param->set_partial_shape(Shape{2, 64});
param->validate_and_infer_types();
ASSERT_THROW(f_node->validate_and_infer_types(), ngraph_error);
}

View File

@@ -14,6 +14,7 @@
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <common_test_utils/ngraph_test_utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@@ -21,6 +22,42 @@
using namespace testing;
using namespace ngraph;
TEST(TransformationTests, PadElimination) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
Shape data_shape{1, 3, 14, 14};
{
auto data = std::make_shared<opset5::Parameter>(element::i32, data_shape);
set_tensor_name(data, "param");
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 0, 0});
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 0, 0});
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
set_tensor_name(pad, "pad");
auto filters = std::make_shared<opset5::Parameter>(element::i32, Shape{1, 3, 4, 4});
auto conv = std::make_shared<opset5::Convolution>(pad, filters, Strides{1, 1},
CoordinateDiff{0, 0}, CoordinateDiff{1, 1}, Shape{1, 1});
set_tensor_name(conv, "conv");
f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::PadFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<opset5::Parameter>(element::i32, data_shape);
set_tensor_names(data, {"param", "pad"});
auto filters = std::make_shared<opset5::Parameter>(element::i32, Shape{1, 3, 4, 4});
auto conv = std::make_shared<opset5::Convolution>(data, filters, Strides{1, 1},
CoordinateDiff{0, 0}, CoordinateDiff{1, 1}, Shape{1, 1},
op::PadType::EXPLICIT);
set_tensor_name(conv, "conv");
f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, PadFusionAvgPoolExcludePad) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
@@ -155,9 +192,11 @@ TEST(TransformationTests, PadFusionConvolutionBackpropData) {
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
set_tensor_name(pad, "pad");
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
set_tensor_name(conv, "conv");
f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
pass::Manager m;
@@ -171,6 +210,7 @@ TEST(TransformationTests, PadFusionConvolutionBackpropData) {
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
set_tensor_name(conv, "conv");
f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
}
@@ -389,12 +429,15 @@ TEST(TransformationTests, NegativePadFusionConvolutionBackpropDataTooSmallPad) {
Shape data_shape{1, 3, 14, 14};
{
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
set_tensor_name(data, "data");
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
set_tensor_name(pad, "pad");
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
set_tensor_name(conv, "conv");
f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
pass::Manager m;
@@ -405,12 +448,15 @@ TEST(TransformationTests, NegativePadFusionConvolutionBackpropDataTooSmallPad) {
}
{
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
set_tensor_name(data, "data");
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
set_tensor_name(pad, "pad");
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
set_tensor_name(conv, "conv");
f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
}

View File

@@ -779,6 +779,14 @@ void check_rt_info(const std::shared_ptr<ngraph::Function>& f) {
}
}
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string & name) {
output.get_tensor_ptr()->set_names({name});
}
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string> & names) {
output.get_tensor_ptr()->set_names(names);
}
NGRAPH_RTTI_DEFINITION(TestOpMultiOut, "TestOp", 0);
namespace attributes {

View File

@@ -101,6 +101,10 @@ inline std::pair<bool, std::string> compare_functions(
void check_rt_info(const std::shared_ptr<ngraph::Function>& f);
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string & name);
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string> & names);
namespace ngraph {
namespace pass {
class InjectionPass;

View File

@@ -175,7 +175,6 @@ extensions/front/kaldi/tdnn_component_replacer.py
extensions/front/LayerNorm.py
extensions/front/Log1p.py
extensions/front/MatMul_normalizer.py
extensions/front/Mish_fusion.py
extensions/front/MoveEmbeddedInputsToInputs.py
extensions/front/mxnet/__init__.py
extensions/front/mxnet/activation.py
@@ -329,12 +328,10 @@ extensions/front/onnx/priorbox_clustered_ext.py
extensions/front/onnx/priorbox_ext.py
extensions/front/onnx/priorgridgenerator_ext.py
extensions/front/onnx/proposal_ext.py
extensions/front/onnx/quantize_dequantize_linear.py
extensions/front/onnx/quantize_ext.py
extensions/front/onnx/quantize_linear_ext.py
extensions/front/onnx/range_ext.py
extensions/front/onnx/reduce_ext.py
extensions/front/onnx/remove_filtering_boxes_by_size.py
extensions/front/onnx/reshape_ext.py
extensions/front/onnx/resize_ext.py
extensions/front/onnx/reverse_sequence_ext.py
@@ -371,13 +368,11 @@ extensions/front/RollWithEmptyAxesReplacer.py
extensions/front/scatter_normalizer.py
extensions/front/SizeReplacer.py
extensions/front/softmax.py
extensions/front/Softplus_fusion.py
extensions/front/softsign_replacer.py
extensions/front/sparse_to_dense_replacer.py
extensions/front/split_normalizer.py
extensions/front/SqueezeNormalize.py
extensions/front/sub.py
extensions/front/Swish_fusion.py
extensions/front/tf/__init__.py
extensions/front/tf/activation_ext.py
extensions/front/tf/argmax_ext.py

View File

@@ -1,48 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.front.Softplus_fusion import SoftplusFusion
from extensions.ops.activation_ops import Mish
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes
class MishFusion(FrontReplacementSubgraph):
"""
The transformation looks for the pattern with Softplus defining the Mish function: Mish(x) = x * tanh(SoftPlus(x)).
"""
enabled = True
def run_after(self):
return [SoftplusFusion]
def pattern(self):
return dict(
nodes=[
('mul', dict(op='Mul')),
('tanh', dict(op='Tanh')),
('softplus', dict(op='SoftPlus')),
],
edges=[
('softplus', 'tanh'),
('tanh', 'mul'),
])
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
mul = match['mul']
mul_name = mul.soft_get('name', mul.id)
softplus = match['softplus']
# determine the input port of Mul which gets the 'input' node output
input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Tanh')
# check that the same tensor provided as input to Mul and SoftPlus
if mul.in_port(input_port_idx).get_source() != softplus.in_port(0).get_source():
return
mish = Mish(graph, {}).create_node()
mish.in_port(0).connect(mul.in_port(input_port_idx).get_source())
mul.out_port(0).get_connection().set_source(mish.out_port(0))
rename_nodes([(mul, mul_name + '/TBR'), (mish, mul_name)])

View File

@@ -1,43 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.ops.activation_ops import SoftPlus
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes
from mo.middle.pattern_match import check_value
class SoftplusFusion(FrontReplacementSubgraph):
"""
The transformation looks for the pattern for the Softplus function: Softplus(x) = ln(1 + e^x)
"""
enabled = True
def pattern(self):
return dict(
nodes=[
('exp', dict(op='Exp')),
('add', dict(op='Add')),
('const_1', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 1.0, atol=1e-6)))),
('ln', dict(op='Log')),
],
edges=[
('exp', 'add', {}),
('const_1', 'add', {}),
('add', 'ln', {}),
])
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
ln = match['ln']
exp = match['exp']
ln_name = ln.soft_get('name', ln.id)
softplus = SoftPlus(graph, {}).create_node()
softplus.in_port(0).connect(exp.in_port(0).get_source())
ln.out_port(0).get_connection().set_source(softplus.out_port(0))
rename_nodes([(ln, ln_name + '/TBR'), (softplus, ln_name)])

View File

@@ -1,87 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.ops.activation_ops import Swish
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes
class SwishWithSigmoidWithoutBeta(FrontReplacementSubgraph):
"""
The transformation looks for the pattern with Sigmoid defining the Swish function: Swish(x) = x * Sigmoid(x)
"""
enabled = True
def pattern(self):
return dict(
nodes=[
('sigmoid', dict(op='Sigmoid')),
('mul', dict(op='Mul')),
],
edges=[
('sigmoid', 'mul', {}),
])
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
sigmoid = match['sigmoid']
mul = match['mul']
mul_name = mul.soft_get('name', mul.id)
# determine the input port of Mul which gets the 'input' node output
mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Sigmoid')
# check that the same tensor provided as input to Mul and Sigmoid
if mul.in_port(mul_input_port_idx).get_source() != sigmoid.in_port(0).get_source():
return
swish = Swish(graph, {}).create_node()
swish.in_port(0).connect(sigmoid.in_port(0).get_source())
mul.out_port(0).get_connection().set_source(swish.out_port(0))
rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)])
class SwishWithSigmoidWithBeta(FrontReplacementSubgraph):
"""
The transformation looks for the pattern with Sigmoid defining the Swish function: Swish(x) = x * Sigmoid(x * beta)
"""
enabled = True
def pattern(self):
return dict(
nodes=[
('sigmoid', dict(op='Sigmoid')),
('beta', dict()),
('mul_beta', dict(op='Mul')),
('mul', dict(op='Mul')),
],
edges=[
('beta', 'mul_beta', {}),
('mul_beta', 'sigmoid', {}),
('sigmoid', 'mul', {}),
])
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
beta = match['beta']
mul = match['mul']
mul_beta = match['mul_beta']
mul_name = mul.soft_get('name', mul.id)
# determine the input port of Muls which get the 'input' node output
mul_beta_input_port_idx = int(mul_beta.in_port(0).get_connection().get_source().node.id == beta.id)
mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Sigmoid')
# check that the same tensor provided as input to Mul and MulBeta
if mul.in_port(mul_input_port_idx).get_source() != mul_beta.in_port(mul_beta_input_port_idx).get_source():
return
swish = Swish(graph, {}).create_node()
swish.in_port(0).connect(mul_beta.in_port(mul_beta_input_port_idx).get_source())
# connect Beta value
swish.in_port(1).connect(mul_beta.in_port(1 - mul_beta_input_port_idx).get_source())
mul.out_port(0).get_connection().set_source(swish.out_port(0))
rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)])

View File

@@ -1,87 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
import numpy as np
from extensions.ops.fakequantize import FakeQuantize
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes
from mo.ops.const import Const
from mo.utils.error import Error
class QuantizeDequantizeLinear(FrontReplacementSubgraph):
"""
Fuses QuantizeLinear and DequantizeLinear nodes into single FakeQuantize.
Covers cases when the values for zero point and scale are same in both QuantizeLinear and DequantizeLinear.
"""
enabled = True
def pattern(self):
return dict(
nodes=[
('quantize', dict(op='QuantizeLinear')),
('dequantize', dict(op='DequantizeLinear')),
],
edges=[
('quantize', 'dequantize', {'in': 0}),
]
)
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
q = match['quantize']
dq = match['dequantize']
q_scale = q.in_port(1).get_source().node
q_zerop = q.in_port(2).get_source().node
dq_scale = dq.in_port(1).get_source().node
dq_zerop = dq.in_port(2).get_source().node
inp_port = q.in_port(0).get_source()
name = inp_port.node.soft_get('name', inp_port.node.id)
# only constant as for zero_point/scale supported
if q_scale.soft_get('type') == 'Const' and dq_scale.soft_get('type') == 'Const' and \
q_zerop.soft_get('type') == 'Const' and dq_zerop.soft_get('type') == 'Const':
# only patterns with same scale/zero_point values for Q and DQ are supported
if q_scale.value == dq_scale.value and q_zerop.value == dq_zerop.value:
log.debug('Found Q-DQ pattern after {}'.format(name))
zero_point_type = q_zerop.value.dtype
# data type affects range of output values: [-128..127] or [0..255]
if zero_point_type == np.int8:
output_min_value = -128.0
output_max_value = 127.0
elif zero_point_type == np.uint8:
output_min_value = 0.0
output_max_value = 255.0
else:
raise Error('Not supported type {} for zero point value in node {}'.format(
zero_point_type, q_zerop.soft_get('name')))
min_value = q_scale.value * (output_min_value - q_zerop.value)
max_value = q_scale.value * (output_max_value - q_zerop.value)
input_min = Const(graph, {'value': np.array(min_value)}).create_node()
input_max = Const(graph, {'value': np.array(max_value)}).create_node()
FQ = FakeQuantize(graph, {
'levels': 256,
'name': match['quantize'].name + '_Dequantize/FakeQuantize'
}).create_node()
FQ.in_port(0).connect(match['quantize'].in_port(0).get_source())
FQ.in_port(1).connect(input_min.out_port(0))
FQ.in_port(2).connect(input_max.out_port(0))
FQ.in_port(3).connect(input_min.out_port(0))
FQ.in_port(4).connect(input_max.out_port(0))
match['dequantize'].out_port(0).get_connection().set_source(FQ.out_port(0))
dq_name = match['dequantize'].soft_get('name', match['dequantize'].id)
rename_nodes([(match['dequantize'], dq_name + '/to_be_removed'), (FQ, dq_name)])
else:
raise Error('QuantizeLinear and DequantizeLinear (after {}) have different scale or zero-point values, '
'cannot fuse into FakeQuantize!'.format(name))

View File

@@ -1,117 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.front.split_normalizer import AttributedVariadicSplitToVariadicSplit
from extensions.ops.range import Range
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.graph.graph import Node
from mo.ops.const import Const
from mo.ops.shape import Shape
from mo.ops.squeeze import Squeeze
from mo.utils.shape import node_to_get_batch_value
def skip_nodes_by_condition(current_node: Node, condition: callable):
while condition(current_node):
current_node = current_node.in_node()
return current_node
class RemoveFilteringBoxesBySize(FrontReplacementSubgraph):
"""
The transformation looks for a sub-graph that selects boxes with nonzero height and width. The output node of this
sub-graph is a Cast node that produces indices of nodes to be preserved. The transformation creates a new sub-graph
that produces a tensor with values from 0 to input.shape[0] to select all boxes. The output of this sub-graph will
be used in the NonMaxSuppression so the implementation of this layer should ignore boxes with negative sizes.
"""
enabled = True
force_clean_up = True
def run_after(self):
return [AttributedVariadicSplitToVariadicSplit]
def pattern(self):
return dict(
nodes=[
('split', dict(op='VariadicSplit')),
('sub_1', dict(op='Sub')),
('sub_2', dict(op='Sub')),
('add_1', dict(op='Add')),
('add_2', dict(op='Add')),
('concat', dict(op='Concat')),
('split_2', dict(op='VariadicSplit')),
('squeeze_1', dict(op='Squeeze')),
('squeeze_2', dict(op='Squeeze')),
('less_1', dict(op='Less')),
('less_2', dict(op='Less')),
('not_1', dict(op='LogicalNot')),
('not_2', dict(op='LogicalNot')),
('cast_11', dict(op='Cast')),
('cast_12', dict(op='Cast')),
('cast_21', dict(op='Cast')),
('cast_22', dict(op='Cast')),
('and', dict(op='LogicalAnd')),
('cast_31', dict(op='Cast')),
('cast_32', dict(op='Cast')),
('nonzero', dict(op='NonZero')),
('transpose', dict(op='Transpose')),
('squeeze', dict(op='Squeeze')),
('cast', dict(op='Cast')),
],
edges=[
('split', 'sub_1', {'in': 0, 'out': 2}),
('split', 'sub_1', {'in': 1, 'out': 0}),
('split', 'sub_2', {'in': 0, 'out': 3}),
('split', 'sub_2', {'in': 1, 'out': 1}),
('sub_1', 'add_1', {}),
('sub_2', 'add_2', {}),
('split', 'concat', {'in': 0, 'out': 0}),
('split', 'concat', {'in': 1, 'out': 1}),
('add_1', 'concat', {'in': 2, 'out': 0}),
('add_2', 'concat', {'in': 3, 'out': 0}),
('concat', 'split_2', {}),
('split_2', 'squeeze_1', {'in': 0, 'out': 2}),
('split_2', 'squeeze_2', {'in': 0, 'out': 3}),
('squeeze_1', 'less_1', {}),
('squeeze_2', 'less_2', {}),
('less_1', 'not_1', {}),
('less_2', 'not_2', {}),
('not_1', 'cast_11', {}),
('cast_11', 'cast_12', {}),
('not_2', 'cast_21', {}),
('cast_21', 'cast_22', {}),
('cast_12', 'and', {}),
('cast_22', 'and', {}),
('and', 'cast_31', {}),
('cast_31', 'cast_32', {}),
('cast_32', 'nonzero', {}),
('nonzero', 'transpose', {}),
('transpose', 'squeeze', {}),
('squeeze', 'cast', {}),
])
def replace_sub_graph(self, graph: Graph, match: dict):
source_connection = match['split'].in_port(0).get_connection()
source_node = source_connection.get_source().node
cast_node = match['cast']
range_node = Range(graph, {'name': source_node.id + '/Range'}).create_node()
start_node = Const(graph, {'name': range_node.id + '/Start', 'value': int64_array(0)}).create_node()
step_node = Const(graph, {'name': range_node.id + '/Step', 'value': int64_array(1)}).create_node()
input_shape_node = Shape(graph, {'name': start_node.id + '/Shape'}).create_node()
input_shape_node.in_port(0).connect(source_node.out_port(0))
limit_node_1D = node_to_get_batch_value(input_shape_node)
limit_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
{'name': source_node.id + '/batch_0D_value'}, limit_node_1D)
range_node.in_port(0).connect(start_node.out_port(0))
range_node.in_port(1).connect(limit_node.out_port(0))
range_node.in_port(2).connect(step_node.out_port(0))
cast_node.out_port(0).get_connection().set_source(range_node.out_port(0))
graph.remove_nodes_from([node.id for node in match.values()])

View File

@@ -1,66 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from extensions.front.Mish_fusion import MishFusion
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op, result, build_graph_with_edge_attrs
ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
**regular_op('mish', {'type': 'Mish', 'name': 'final_mul'}),
**result('result')
}
ref_edges = [('input', 'mish'), ('mish', 'result')]
class MishFusionTest(unittest.TestCase):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**regular_op('softplus', {'op': 'SoftPlus'}),
**regular_op('tanh', {'op': 'Tanh'}),
**regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
**result('result'),
}
edges = [('input', 'softplus', {'in': 0, 'out': 0}),
('input', 'mul', {'in': 0, 'out': 0}),
('softplus', 'tanh', {'in': 0, 'out': 0}),
('tanh', 'mul', {'in': 1, 'out': 0}),
('mul', 'result', {'in': 0, 'out': 0})]
def test_mish_fusion(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
graph_ref = build_graph(ref_nodes, ref_edges)
graph.stage = 'front'
MishFusion().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
graph.get_op_nodes(name='final_mul')[0].op == 'Mish')
def test_mish_fusion_different_source(self):
# check case when different tensors goes to Mul and SoftPlus
graph = build_graph_with_edge_attrs({
**regular_op('input', {'type': 'Parameter'}),
**regular_op('input_2', {'type': 'Parameter'}),
**regular_op('softplus', {'op': 'SoftPlus'}),
**regular_op('tanh', {'op': 'Tanh'}),
**regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
**result('result'),
}, [('input', 'softplus', {'in': 0, 'out': 0}),
('input_2', 'mul', {'in': 0, 'out': 0}),
('softplus', 'tanh', {'in': 0, 'out': 0}),
('tanh', 'mul', {'in': 1, 'out': 0}),
('mul', 'result', {'in': 0, 'out': 0})], {})
graph_ref = graph.copy()
graph.stage = 'front'
MishFusion().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

View File

@@ -1,57 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from extensions.front.Softplus_fusion import SoftplusFusion
from mo.front.common.partial_infer.utils import float_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs
ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
**regular_op('softplus', {'type': 'SoftPlus', 'name': 'final_log'}),
**result('result')
}
ref_edges = [('input', 'softplus'), ('softplus', 'result')]
class SoftplusFusionTest(unittest.TestCase):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**regular_op('exp', {'op': 'Exp'}),
**const('const_1', float_array([1.0])),
**regular_op('add', {'op': 'Add'}),
**regular_op('ln', {'op': 'Log', 'name': 'final_log'}),
**result('result'),
}
edges = [('input', 'exp', {'in': 0, 'out': 0}),
('const_1', 'add', {'in': 0, 'out': 0}),
('exp', 'add', {'in': 1, 'out': 0}),
('add', 'ln', {'in': 0, 'out': 0}),
('ln', 'result', {'in': 0, 'out': 0})]
def test_softplus_fusion_test(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
graph_ref = build_graph(ref_nodes, ref_edges)
graph.stage = 'front'
SoftplusFusion().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(len(graph.get_op_nodes(name='final_log')) == 1 and
graph.get_op_nodes(name='final_log')[0].op == 'SoftPlus')
def test_softplus_fusion_test_wrong_const(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_1': {'value': float_array([0.9999])}})
graph_ref = graph.copy()
graph.stage = 'front'
SoftplusFusion().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

View File

@@ -1,119 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from extensions.front.Swish_fusion import SwishWithSigmoidWithoutBeta, SwishWithSigmoidWithBeta
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op, result, build_graph_with_edge_attrs
ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
**regular_op('swish', {'type': 'Swish', 'name': 'final_mul'}),
**result('result')
}
ref_edges = [('input', 'swish'), ('swish', 'result')]
class SwishWithSigmoidWithoutBetaTest(unittest.TestCase):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**regular_op('sigmoid', {'op': 'Sigmoid'}),
**regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
**result('result'),
}
edges = [('input', 'mul', {'in': 0, 'out': 0}),
('input', 'sigmoid', {'in': 0, 'out': 0}),
('sigmoid', 'mul', {'in': 1, 'out': 0}),
('mul', 'result', {'in': 0, 'out': 0})]
def test_swish_with_sigmoid_without_beta_test(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
graph_ref = build_graph(ref_nodes, ref_edges)
graph.stage = 'front'
SwishWithSigmoidWithoutBeta().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
graph.get_op_nodes(name='final_mul')[0].op == 'Swish')
def test_swish_with_sigmoid_without_beta_different_tensors(self):
graph = build_graph_with_edge_attrs({
**regular_op('input', {'type': 'Parameter'}),
**regular_op('input_2', {'type': 'Parameter'}),
**regular_op('sigmoid', {'op': 'Sigmoid'}),
**regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
**result('result'),
}, [('input_2', 'mul', {'in': 0, 'out': 0}),
('input', 'sigmoid', {'in': 0, 'out': 0}),
('sigmoid', 'mul', {'in': 1, 'out': 0}),
('mul', 'result', {'in': 0, 'out': 0})], {})
graph_ref = graph.copy()
graph.stage = 'front'
SwishWithSigmoidWithoutBeta().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
class SwishWithSigmoidWithBetaTest(unittest.TestCase):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**regular_op('beta', {'type': 'Parameter'}),
**regular_op('mul_beta', {'op': 'Mul'}),
**regular_op('sigmoid', {'op': 'Sigmoid'}),
**regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
**result('result'),
}
edges = [('input', 'mul_beta', {'in': 0, 'out': 0}),
('input', 'mul_2', {'in': 0, 'out': 0}),
('beta', 'mul_beta', {'in': 1, 'out': 0}),
('mul_beta', 'sigmoid', {'in': 0, 'out': 0}),
('sigmoid', 'mul_2', {'in': 1, 'out': 0}),
('mul_2', 'result', {'in': 0, 'out': 0})]
def test_swish_with_sigmoid_with_beta_test(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
new_ref_nodes = ref_nodes.copy()
new_ref_nodes.update(**regular_op('beta', {'type': 'Parameter'}))
graph_ref = build_graph(new_ref_nodes, ref_edges + [('beta', 'swish')])
graph.stage = 'front'
SwishWithSigmoidWithBeta().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
graph.get_op_nodes(name='final_mul')[0].op == 'Swish')
def test_swish_with_sigmoid_with_beta_different_tensors(self):
graph = build_graph_with_edge_attrs({
**regular_op('input', {'type': 'Parameter'}),
**regular_op('input_2', {'type': 'Parameter'}),
**regular_op('beta', {'type': 'Parameter'}),
**regular_op('mul_beta', {'op': 'Mul'}),
**regular_op('sigmoid', {'op': 'Sigmoid'}),
**regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
**result('result'),
}, [('input', 'mul_beta', {'in': 0, 'out': 0}),
('input_2', 'mul_2', {'in': 0, 'out': 0}),
('beta', 'mul_beta', {'in': 1, 'out': 0}),
('mul_beta', 'sigmoid', {'in': 0, 'out': 0}),
('sigmoid', 'mul_2', {'in': 1, 'out': 0}),
('mul_2', 'result', {'in': 0, 'out': 0})], {})
graph_ref = graph.copy()
graph.stage = 'front'
SwishWithSigmoidWithBeta().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

View File

@@ -1,115 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.front.onnx.quantize_dequantize_linear import QuantizeDequantizeLinear
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph
# quantize and dequantize share tensors with scale/zp
nodes0_attributes = {
'input': {'kind': 'op', 'op': 'AnyOp'},
'quantize': {'kind': 'op', 'op': 'QuantizeLinear'},
'dequantize': {'kind': 'op', 'op': 'DequantizeLinear'},
'scale_param': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
'zerop_param': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
'out': {'kind': 'op', 'op': 'AnyOp'},
}
# quantize and dequantize do not share tensors with scale/zp
nodes1_attributes = {
'input': {'kind': 'op', 'op': 'AnyOp'},
'quantize': {'kind': 'op', 'op': 'QuantizeLinear'},
'dequantize': {'kind': 'op', 'op': 'DequantizeLinear'},
'scale_param_q': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
'zerop_param_q': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
'scale_param_dq': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
'zerop_param_dq': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
'out': {'kind': 'op', 'op': 'AnyOp'},
}
nodes_ref_attributes = {
'input': {'kind': 'op', 'op': 'AnyOp'},
'fq': {'kind': 'op', 'op': 'FakeQuantize'},
'min_param': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
'max_param': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
'out': {'kind': 'op', 'op': 'AnyOp'},
}
class TestQuantizeDeQuantize2FakeQuantize(unittest.TestCase):
def test_quantizedequantize2fakequantize_0(self):
# testing the code path with uint8 zero-point
graph = build_graph(nodes1_attributes,
[('input', 'quantize'),
('quantize', 'dequantize'),
('scale_param_q', 'quantize'),
('zerop_param_q', 'quantize'),
('scale_param_dq', 'dequantize'),
('zerop_param_dq', 'dequantize'),
('dequantize', 'out'),
],
{'scale_param_q': {'shape': np.array([1]), 'value': np.float32(1.0 / 255)},
'zerop_param_q': {'shape': np.array([1]), 'value': np.uint8(0)},
'scale_param_dq': {'shape': np.array([1]), 'value': np.float32(1.0 / 255)},
'zerop_param_dq': {'shape': np.array([1]), 'value': np.uint8(0)},
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_ref_attributes,
[('input', 'fq', {'in': 0}),
('min_param', 'fq', {'out': 0, 'in': 1}),
('min_param', 'fq', {'out': 0, 'in': 3}),
('max_param', 'fq', {'out': 0, 'in': 2}),
('max_param', 'fq', {'out': 0, 'in': 4}),
('fq', 'out'),
],
{'fq': {'levels': 256},
'min_param': {'value': np.float32(0.0)},
'max_param': {'value': np.float32(1.0)},
}, nodes_with_edges_only=True)
graph.stage = 'front'
tested_class = QuantizeDequantizeLinear()
tested_class.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'out', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_quantizedequantize2fakequantize_1(self):
# testing the code path with int8 zero-point
graph = build_graph(nodes0_attributes,
[('input', 'quantize'),
('quantize', 'dequantize'),
('scale_param', 'quantize'),
('zerop_param', 'quantize'),
('scale_param', 'dequantize'),
('zerop_param', 'dequantize'),
('dequantize', 'out'),
],
{'scale_param': {'shape': np.array([1]), 'value': np.float32(1.0 / 255)},
'zerop_param': {'shape': np.array([1]), 'value': np.int8(0)},
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_ref_attributes,
[('input', 'fq', {'in': 0}),
('min_param', 'fq', {'out': 0, 'in': 1}),
('min_param', 'fq', {'out': 0, 'in': 3}),
('max_param', 'fq', {'out': 0, 'in': 2}),
('max_param', 'fq', {'out': 0, 'in': 4}),
('fq', 'out'),
],
{'fq': {'levels': 256},
'min_param': {'value': np.float32(-128.0 / 255)},
'max_param': {'value': np.float32(127.0 / 255)},
}, nodes_with_edges_only=True)
graph.stage = 'front'
tested_class = QuantizeDequantizeLinear()
tested_class.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'out', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@@ -45,6 +45,7 @@ namespace ngraph
const std::unordered_set<std::string>& get_names() const;
void set_names(const std::unordered_set<std::string>& names);
void add_names(const std::unordered_set<std::string>& names);
void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
void set_element_type(const element::Type& element_type);
void set_partial_shape(const PartialShape& partial_shape);

View File

@@ -123,6 +123,14 @@ void descriptor::Tensor::set_names(const std::unordered_set<std::string>& names)
m_names = names;
}
void descriptor::Tensor::add_names(const std::unordered_set<std::string>& names)
{
for (const auto& name : names)
{
m_names.insert(name);
}
}
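A small illustration (an editor's sketch, not from the diff) of the difference between set_names and add_names: set_names replaces the whole name set, while add_names merges into it, which is what replace_output_update_name relies on below.
auto param = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{1});
param->output(0).get_tensor().set_names({"a"});
param->output(0).get_tensor().add_names({"b", "c"});
// get_names() now returns {"a", "b", "c"}; a second set_names() call would replace the set.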
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
{
std::string names;

View File

@@ -167,11 +167,9 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
// Change I's connected upstream output to O_rep
for (size_t i = 0; i < target->get_output_size(); i++)
{
target->output(i).replace(replacement->output(output_order[i]));
}
replacement->add_node_control_dependents(target);
replacement->add_node_control_dependencies(target);
target->clear_control_dependents();
@@ -912,7 +910,15 @@ bool ngraph::replace_output_update_name(Output<Node> output, const Output<Node>&
replacement.get_tensor().set_name(output.get_node()->get_friendly_name());
NGRAPH_SUPPRESS_DEPRECATED_END
}
// Save replacement tensor names before replacement as they will be
// overridden by the output tensor names
auto output_names = replacement.get_tensor_ptr()->get_names();
output.replace(replacement);
// Restore the original replacement tensor names
replacement.get_tensor().add_names(output_names);
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
replacement.get_node_shared_ptr());
return true;

View File

@@ -76,6 +76,7 @@ namespace ngraph
{
input.replace_source_output(replacement);
}
replacement.get_tensor_ptr()->set_names(get_tensor_ptr()->get_names());
}
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;

View File

@@ -57,11 +57,13 @@ void op::util::ScatterNDBase::validate_and_infer_types()
NODE_VALIDATION_CHECK(this,
inputs_rank.is_dynamic() || indices_rank.is_dynamic() ||
indices_shape[indices_rank.get_length() - 1].is_dynamic() ||
indices_shape[indices_rank.get_length() - 1].get_length() <=
inputs_rank.get_length(),
"Last dimension of indices can be at most the rank of inputs");
if (inputs_rank.is_static() && indices_rank.is_static() && updates_rank.is_static() &&
indices_shape[indices_rank.get_length() - 1].is_static())
{
auto expected_updates_rank = indices_rank.get_length() + inputs_rank.get_length() -
indices_shape[indices_rank.get_length() - 1].get_length() - 1;
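A hedged illustration of the fixed case (assuming opset4 and the usual ngraph namespaces): when the last dimension of the indices shape is dynamic, the expected updates rank cannot be computed, so the updates-rank check is now skipped instead of reading a dynamic dimension length.
auto data = std::make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(3));
auto indices = std::make_shared<opset4::Parameter>(element::i32, PartialShape{Dimension::dynamic(), Dimension::dynamic()});
auto updates = std::make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic());
// Construction triggers validate_and_infer_types(); with this fix it no longer
// tries to read the length of the (dynamic) last dimension of indices.
auto scatter = std::make_shared<opset4::ScatterNDUpdate>(data, indices, updates);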

View File

@@ -108,3 +108,63 @@ TEST(replace_node, replace_nodes)
ASSERT_EQ(z_replacement->get_input_node_shared_ptr(0), x_replacement);
ASSERT_EQ(z_replacement->get_input_node_shared_ptr(1), mul);
}
TEST(replace_node, simple_node_replacement)
{
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
param->output(0).get_tensor().set_names({"a", "b"});
auto relu = std::make_shared<op::Relu>(param);
relu->output(0).get_tensor().set_names({"c", "d"});
auto new_relu = std::make_shared<op::Relu>(param);
new_relu->output(0).get_tensor().set_names({"f"});
replace_node(relu, new_relu);
ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d"}));
}
TEST(replace_node, node_elimination)
{
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
param->output(0).get_tensor().set_names({"a", "b"});
auto relu1 = std::make_shared<op::Relu>(param);
relu1->output(0).get_tensor().set_names({"c", "d"});
auto relu2 = std::make_shared<op::Relu>(relu1);
relu2->output(0).get_tensor().set_names({"e", "f"});
ASSERT_TRUE(replace_output_update_name(relu2->output(0), relu2->input_value(0)));
ASSERT_EQ(relu1->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d", "e", "f"}));
ASSERT_EQ(param->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"a", "b"}));
}
TEST(replace_node, output_replacement)
{
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
param->output(0).get_tensor().set_names({"a", "b"});
auto relu = std::make_shared<op::Relu>(param);
relu->output(0).get_tensor().set_names({"c", "d"});
auto new_relu = std::make_shared<op::Relu>(param);
new_relu->output(0).get_tensor().set_names({"f"});
relu->output(0).replace(new_relu->output(0));
ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d"}));
}
TEST(replace_node, source_replacement)
{
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
param->output(0).get_tensor().set_names({"a", "b"});
auto param1 = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
param1->output(0).get_tensor().set_names({"c", "d"});
auto relu = std::make_shared<op::Relu>(param);
relu->input(0).replace_source_output(param1->output(0));
ASSERT_EQ(param->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"a", "b"}));
ASSERT_EQ(param1->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d"}));
}