Enable common optimizations on MO (#6245)
* Enable common optimizations on MO * Added tensor name tracking; updated tests * Disable DilatedConvolution transform * Fix TopK3 transformation * Codestyle fix * Update tensor name logic * Fix scatter nd shape inference for dynamic shape * Update FrameworkNode to propagate dynamic output shape * Enable HSwish in MO that is missing in nGrpah * Cleanup MO and IE code * Fix review comments * Fix unit test
This commit is contained in:
parent
8b52a4c0c5
commit
2a970a56d3
@ -5,12 +5,54 @@
|
||||
#include <memory>
|
||||
|
||||
#include "moc_transformations.hpp"
|
||||
#include "pruning.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/common_optimizations/gelu_fusion.hpp>
|
||||
#include <transformations/common_optimizations/softplus_fusion.hpp>
|
||||
#include <transformations/common_optimizations/softplus_to_mish_fusion.hpp>
|
||||
#include <transformations/common_optimizations/swish_fusion.hpp>
|
||||
#include <transformations/common_optimizations/remove_filtering_boxes_by_size.hpp>
|
||||
#include <transformations/common_optimizations/hsigmoid_fusion.hpp>
|
||||
#include <transformations/common_optimizations/hswish_fusion.hpp>
|
||||
#include <transformations/common_optimizations/convert_quantize_dequantize.hpp>
|
||||
#include <transformations/common_optimizations/pad_fusion.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
|
||||
|
||||
bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
// To avoid issues with dynamism we make nGraph Function dynamic and after we apply all
|
||||
// transformations we restore original shapes to the nGraph Function back
|
||||
std::unordered_map<ngraph::op::Parameter*, PartialShape> input_shapes;
|
||||
for (auto && param : f->get_parameters()) {
|
||||
input_shapes[param.get()] = param->get_partial_shape();
|
||||
param->set_partial_shape(PartialShape::dynamic(param->get_partial_shape().rank()));
|
||||
}
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
||||
ngraph::pass::Manager manager(get_pass_config());
|
||||
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
|
||||
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
||||
|
||||
auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SoftPlusToMishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SwishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::HSwishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::HSigmoidFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::PadFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
|
||||
common_fusions->set_name("ngraph::pass::CommonFusions");
|
||||
|
||||
manager.run_passes(f);
|
||||
|
||||
// Restore original shapes to the nGraph Function
|
||||
for (auto && param : f->get_parameters()) {
|
||||
param->set_partial_shape(input_shapes.at(param.get()));
|
||||
}
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
||||
return false;
|
||||
}
|
@ -71,8 +71,11 @@ public:
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
void cache_output_descriptor();
|
||||
|
||||
private:
|
||||
std::vector<std::tuple<ngraph::PartialShape, ngraph::element::Type>> m_inputs_desc;
|
||||
std::vector<std::tuple<ngraph::PartialShape, ngraph::element::Type>> m_output_desc;
|
||||
|
||||
FrameworkNodeAttrs m_attrs;
|
||||
};
|
||||
|
@ -12,6 +12,7 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API PadFusion;
|
||||
class TRANSFORMATIONS_API PadElimination;
|
||||
class TRANSFORMATIONS_API PadFusionAvgPool;
|
||||
class TRANSFORMATIONS_API PadFusionMaxPool;
|
||||
class TRANSFORMATIONS_API PadFusionConvolution;
|
||||
@ -22,6 +23,16 @@ class TRANSFORMATIONS_API PadFusionGroupConvolutionBackpropData;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PadElimination eliminates pad that does nothing
|
||||
*/
|
||||
class ngraph::pass::PadElimination: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
PadElimination();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PadFusion transformation replaces following graph:
|
||||
@ -113,5 +124,6 @@ public:
|
||||
add_matcher<ngraph::pass::PadFusionConvolutionBackpropData>();
|
||||
add_matcher<ngraph::pass::PadFusionGroupConvolution>();
|
||||
add_matcher<ngraph::pass::PadFusionGroupConvolutionBackpropData>();
|
||||
add_matcher<ngraph::pass::PadElimination>();
|
||||
}
|
||||
};
|
||||
|
@ -25,31 +25,71 @@ shared_ptr<Node> op::FrameworkNode::clone_with_new_inputs(const OutputVector& ne
|
||||
return node;
|
||||
}
|
||||
|
||||
void op::FrameworkNode::cache_output_descriptor() {
|
||||
for (size_t i = 0; i < get_output_size(); ++i) {
|
||||
m_output_desc.emplace_back(get_output_partial_shape(i), get_output_element_type(i));
|
||||
}
|
||||
}
|
||||
|
||||
void op::FrameworkNode::validate_and_infer_types() {
|
||||
INTERNAL_OP_SCOPE(FrameworkNode_validate_and_infer_types);
|
||||
// Save initial inputs descriptors
|
||||
bool initialize_input_desc = m_inputs_desc.empty();
|
||||
bool reset_output_shape_to_dynamic = false;
|
||||
bool reset_output_shape_to_original = false;
|
||||
for (uint64_t i = 0; i < get_input_size(); i++) {
|
||||
// TODO: store constant values
|
||||
const auto& new_input_desc =
|
||||
std::make_tuple(get_input_partial_shape(i), get_input_element_type(i));
|
||||
const auto& input_pshape = get_input_partial_shape(i);
|
||||
const auto& input_type = get_input_element_type(i);
|
||||
const auto& rank = input_pshape.rank();
|
||||
|
||||
const auto & get_error_message = [&]() {
|
||||
std::stringstream out;
|
||||
out << "Input descriptor for " << get_friendly_name()
|
||||
<< " node has been changed:" << std::endl;
|
||||
out << "Before: " << std::get<0>(m_inputs_desc[i]) << ", "
|
||||
<< std::get<1>(m_inputs_desc[i]) << std::endl;
|
||||
out << "After: " << input_pshape << ", "
|
||||
<< input_type << std::endl;
|
||||
out << "Please specify InferenceEngine Extensions to support this case.";
|
||||
return out.str();
|
||||
};
|
||||
|
||||
if (initialize_input_desc) {
|
||||
m_inputs_desc.push_back(new_input_desc);
|
||||
m_inputs_desc.emplace_back(input_pshape, input_type);
|
||||
} else {
|
||||
auto get_message = [&]() {
|
||||
std::stringstream out;
|
||||
out << "Input descriptor for " << get_friendly_name()
|
||||
<< " node has been changed:" << std::endl;
|
||||
out << "Before: " << std::get<0>(m_inputs_desc[i]) << ", "
|
||||
<< std::get<1>(m_inputs_desc[i]) << std::endl;
|
||||
out << "After: " << std::get<0>(new_input_desc) << ", "
|
||||
<< std::get<1>(new_input_desc) << std::endl;
|
||||
out << "Please specify InferenceEngine Extensions to support this case.";
|
||||
return out.str();
|
||||
};
|
||||
const auto& orig_input_pshape = std::get<0>(m_inputs_desc[i]);
|
||||
if (orig_input_pshape == input_pshape) {
|
||||
reset_output_shape_to_original = true;
|
||||
} else if (input_pshape.rank().is_dynamic()) {
|
||||
reset_output_shape_to_dynamic = true;
|
||||
} else if (rank.is_static() && orig_input_pshape.rank().is_static() &&
|
||||
rank.get_length() == orig_input_pshape.rank().get_length()) {
|
||||
for (int64_t dim = 0; dim < rank.get_length(); ++dim) {
|
||||
NODE_VALIDATION_CHECK(this, input_pshape[dim].is_dynamic() ||
|
||||
(orig_input_pshape[dim].is_static() &&
|
||||
orig_input_pshape[dim].get_length() == input_pshape[dim].get_length()),
|
||||
get_error_message());
|
||||
}
|
||||
reset_output_shape_to_dynamic = true;
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(this, m_inputs_desc[i] == std::make_tuple(input_pshape, input_type), get_error_message());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(this, m_inputs_desc[i] == new_input_desc, get_message());
|
||||
if (reset_output_shape_to_dynamic) {
|
||||
cache_output_descriptor();
|
||||
for (size_t i = 0; i < get_output_size(); ++i) {
|
||||
if (get_output_partial_shape(i).rank().is_static()) {
|
||||
set_output_type(i, get_output_element_type(i), PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (reset_output_shape_to_original && !m_output_desc.empty()) {
|
||||
for (size_t i = 0; i < get_output_size(); ++i) {
|
||||
set_output_type(i, std::get<1>(m_output_desc[i]), std::get<0>(m_output_desc[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
@ -385,3 +386,34 @@ pass::PadFusionGroupConvolutionBackpropData::PadFusionGroupConvolutionBackpropDa
|
||||
auto m = std::make_shared<pattern::Matcher>(conv_pattern, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(pass::PadElimination, "PadElimination", 0);
|
||||
|
||||
pass::PadElimination::PadElimination() {
|
||||
MATCHER_SCOPE(PadElimination);
|
||||
auto pad_node_pattern = pattern::wrap_type<opset5::Pad>();
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto pad = m.get_match_root();
|
||||
|
||||
auto pad_begin_const = ngraph::get_constant_from_source(pad->input_value(1));
|
||||
auto pad_end_const = ngraph::get_constant_from_source(pad->input_value(2));
|
||||
|
||||
if (!pad_begin_const || !pad_end_const) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto pad_begin_value = pad_begin_const->cast_vector<int64_t>();
|
||||
const auto pad_end_value = pad_end_const->cast_vector<int64_t>();
|
||||
|
||||
if (std::any_of(pad_begin_value.begin(), pad_begin_value.end(), [](int64_t value) { return value != 0; }) ||
|
||||
std::any_of(pad_end_value.begin(), pad_end_value.end(), [](int64_t value) { return value != 0; })) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return replace_output_update_name(pad->output(0), pad->input_value(0));
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(pad_node_pattern, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
@ -19,26 +19,20 @@ ngraph::pass::SoftPlusFusion::SoftPlusFusion() {
|
||||
// fuses ln(exp(x) + 1.0) operations into SoftPlus(x)
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto exp = std::make_shared<ngraph::opset4::Exp>(input);
|
||||
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>(
|
||||
pattern::type_matches_any({element::f32, element::f16}));
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
|
||||
auto log = std::make_shared<ngraph::opset4::Log>(add);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||
auto &pattern_to_output = m.get_pattern_value_map();
|
||||
const auto &pattern_to_output = m.get_pattern_value_map();
|
||||
auto exp_input = pattern_to_output.at(input);
|
||||
|
||||
auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
|
||||
if (!constant) return false;
|
||||
|
||||
if (constant == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) {
|
||||
auto data = constant->cast_vector<float>();
|
||||
if (data.size() != 1 || data[0] != 1.0) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto data = constant->cast_vector<float>();
|
||||
if (data.size() != 1 || data[0] != 1.0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -40,6 +40,7 @@ ngraph::pass::ConvertTopK3::ConvertTopK3() {
|
||||
last1 = new_topk->output(1);
|
||||
new_topk->set_friendly_name(topk->get_friendly_name());
|
||||
} else if (topk->get_output_target_inputs(0).size() == 0) {
|
||||
last0 = topk->output(0);
|
||||
last1 = std::make_shared<ngraph::opset2::Convert>(new_topk->output(1), topk->get_index_element_type());
|
||||
new_ops.push_back(last1.get_node_shared_ptr());
|
||||
|
||||
|
@ -0,0 +1,59 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph_ops/framework_node.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
|
||||
TEST(TransformationTests, FrameworkNode) {
|
||||
auto param = std::make_shared<ngraph::opset8::Parameter>(element::i64, Shape{1, 64});
|
||||
auto f_node = std::make_shared<ngraph::op::FrameworkNode>(OutputVector{param});
|
||||
f_node->set_output_type(0, element::i64, Shape{1, 64});
|
||||
|
||||
// Set partially dynamic shape
|
||||
param->set_partial_shape(PartialShape{Dimension::dynamic(), 64});
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), PartialShape::dynamic());
|
||||
|
||||
// Set dynamic shape
|
||||
param->set_partial_shape(PartialShape::dynamic(2));
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), PartialShape::dynamic());
|
||||
|
||||
// Set fully dynamic shape
|
||||
param->set_partial_shape(PartialShape::dynamic());
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), PartialShape::dynamic());
|
||||
|
||||
// Set original static shape
|
||||
param->set_partial_shape(Shape{1, 64});
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), PartialShape({1, 64}));
|
||||
|
||||
// Set different static shape
|
||||
param->set_partial_shape(Shape{2, 64});
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_THROW(f_node->validate_and_infer_types(), ngraph_error::exception);
|
||||
}
|
@ -14,6 +14,7 @@
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <common_test_utils/ngraph_test_utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -21,6 +22,42 @@
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(TransformationTests, PadElimination) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
Shape data_shape{1, 3, 14, 14};
|
||||
{
|
||||
auto data = std::make_shared<opset5::Parameter>(element::i32, data_shape);
|
||||
set_tensor_name(data, "param");
|
||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 0, 0});
|
||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 0, 0});
|
||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
set_tensor_name(pad, "pad");
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::i32, Shape{1, 3, 4, 4});
|
||||
auto conv = std::make_shared<opset5::Convolution>(pad, filters, Strides{1, 1},
|
||||
CoordinateDiff{0, 0}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::PadFusion>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset5::Parameter>(element::i32, data_shape);
|
||||
set_tensor_names(data, {"param", "pad"});
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::i32, Shape{1, 3, 4, 4});
|
||||
auto conv = std::make_shared<opset5::Convolution>(data, filters, Strides{1, 1},
|
||||
CoordinateDiff{0, 0}, CoordinateDiff{1, 1}, Shape{1, 1},
|
||||
op::PadType::EXPLICIT);
|
||||
set_tensor_name(conv, "conv");
|
||||
f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, PadFusionAvgPoolExcludePad) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
@ -155,9 +192,11 @@ TEST(TransformationTests, PadFusionConvolutionBackpropData) {
|
||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
|
||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
set_tensor_name(pad, "pad");
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
pass::Manager m;
|
||||
@ -171,6 +210,7 @@ TEST(TransformationTests, PadFusionConvolutionBackpropData) {
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
|
||||
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
}
|
||||
@ -389,12 +429,15 @@ TEST(TransformationTests, NegativePadFusionConvolutionBackpropDataTooSmallPad) {
|
||||
Shape data_shape{1, 3, 14, 14};
|
||||
{
|
||||
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
|
||||
set_tensor_name(data, "data");
|
||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
set_tensor_name(pad, "pad");
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
pass::Manager m;
|
||||
@ -405,12 +448,15 @@ TEST(TransformationTests, NegativePadFusionConvolutionBackpropDataTooSmallPad) {
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
|
||||
set_tensor_name(data, "data");
|
||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
set_tensor_name(pad, "pad");
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
}
|
||||
|
@ -779,6 +779,14 @@ void check_rt_info(const std::shared_ptr<ngraph::Function>& f) {
|
||||
}
|
||||
}
|
||||
|
||||
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string & name) {
|
||||
output.get_tensor_ptr()->set_names({name});
|
||||
}
|
||||
|
||||
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string> & names) {
|
||||
output.get_tensor_ptr()->set_names(names);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(TestOpMultiOut, "TestOp", 0);
|
||||
|
||||
namespace attributes {
|
||||
|
@ -101,6 +101,10 @@ inline std::pair<bool, std::string> compare_functions(
|
||||
|
||||
void check_rt_info(const std::shared_ptr<ngraph::Function>& f);
|
||||
|
||||
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string & name);
|
||||
|
||||
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string> & names);
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
class InjectionPass;
|
||||
|
@ -175,7 +175,6 @@ extensions/front/kaldi/tdnn_component_replacer.py
|
||||
extensions/front/LayerNorm.py
|
||||
extensions/front/Log1p.py
|
||||
extensions/front/MatMul_normalizer.py
|
||||
extensions/front/Mish_fusion.py
|
||||
extensions/front/MoveEmbeddedInputsToInputs.py
|
||||
extensions/front/mxnet/__init__.py
|
||||
extensions/front/mxnet/activation.py
|
||||
@ -329,12 +328,10 @@ extensions/front/onnx/priorbox_clustered_ext.py
|
||||
extensions/front/onnx/priorbox_ext.py
|
||||
extensions/front/onnx/priorgridgenerator_ext.py
|
||||
extensions/front/onnx/proposal_ext.py
|
||||
extensions/front/onnx/quantize_dequantize_linear.py
|
||||
extensions/front/onnx/quantize_ext.py
|
||||
extensions/front/onnx/quantize_linear_ext.py
|
||||
extensions/front/onnx/range_ext.py
|
||||
extensions/front/onnx/reduce_ext.py
|
||||
extensions/front/onnx/remove_filtering_boxes_by_size.py
|
||||
extensions/front/onnx/reshape_ext.py
|
||||
extensions/front/onnx/resize_ext.py
|
||||
extensions/front/onnx/reverse_sequence_ext.py
|
||||
@ -371,13 +368,11 @@ extensions/front/RollWithEmptyAxesReplacer.py
|
||||
extensions/front/scatter_normalizer.py
|
||||
extensions/front/SizeReplacer.py
|
||||
extensions/front/softmax.py
|
||||
extensions/front/Softplus_fusion.py
|
||||
extensions/front/softsign_replacer.py
|
||||
extensions/front/sparse_to_dense_replacer.py
|
||||
extensions/front/split_normalizer.py
|
||||
extensions/front/SqueezeNormalize.py
|
||||
extensions/front/sub.py
|
||||
extensions/front/Swish_fusion.py
|
||||
extensions/front/tf/__init__.py
|
||||
extensions/front/tf/activation_ext.py
|
||||
extensions/front/tf/argmax_ext.py
|
||||
|
@ -1,48 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.front.Softplus_fusion import SoftplusFusion
|
||||
from extensions.ops.activation_ops import Mish
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.graph.graph import Graph, rename_nodes
|
||||
|
||||
|
||||
class MishFusion(FrontReplacementSubgraph):
|
||||
"""
|
||||
The transformation looks for the pattern with Softplus defining the Mish function: Mish(x) = x * tanh(SoftPlus(x)).
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
return [SoftplusFusion]
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('mul', dict(op='Mul')),
|
||||
('tanh', dict(op='Tanh')),
|
||||
('softplus', dict(op='SoftPlus')),
|
||||
],
|
||||
edges=[
|
||||
('softplus', 'tanh'),
|
||||
('tanh', 'mul'),
|
||||
])
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
|
||||
mul = match['mul']
|
||||
mul_name = mul.soft_get('name', mul.id)
|
||||
softplus = match['softplus']
|
||||
|
||||
# determine the input port of Mul which gets the 'input' node output
|
||||
input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Tanh')
|
||||
|
||||
# check that the same tensor provided as input to Mul and SoftPlus
|
||||
if mul.in_port(input_port_idx).get_source() != softplus.in_port(0).get_source():
|
||||
return
|
||||
|
||||
mish = Mish(graph, {}).create_node()
|
||||
mish.in_port(0).connect(mul.in_port(input_port_idx).get_source())
|
||||
mul.out_port(0).get_connection().set_source(mish.out_port(0))
|
||||
|
||||
rename_nodes([(mul, mul_name + '/TBR'), (mish, mul_name)])
|
@ -1,43 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.activation_ops import SoftPlus
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.graph.graph import Graph, rename_nodes
|
||||
from mo.middle.pattern_match import check_value
|
||||
|
||||
|
||||
class SoftplusFusion(FrontReplacementSubgraph):
|
||||
"""
|
||||
The transformation looks for the pattern for the Softplus function: Softplus(x) = ln(1 + e^x)
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('exp', dict(op='Exp')),
|
||||
('add', dict(op='Add')),
|
||||
('const_1', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 1.0, atol=1e-6)))),
|
||||
('ln', dict(op='Log')),
|
||||
],
|
||||
edges=[
|
||||
('exp', 'add', {}),
|
||||
('const_1', 'add', {}),
|
||||
('add', 'ln', {}),
|
||||
])
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
|
||||
ln = match['ln']
|
||||
exp = match['exp']
|
||||
|
||||
ln_name = ln.soft_get('name', ln.id)
|
||||
|
||||
softplus = SoftPlus(graph, {}).create_node()
|
||||
softplus.in_port(0).connect(exp.in_port(0).get_source())
|
||||
ln.out_port(0).get_connection().set_source(softplus.out_port(0))
|
||||
|
||||
rename_nodes([(ln, ln_name + '/TBR'), (softplus, ln_name)])
|
@ -1,87 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.ops.activation_ops import Swish
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.graph.graph import Graph, rename_nodes
|
||||
|
||||
|
||||
class SwishWithSigmoidWithoutBeta(FrontReplacementSubgraph):
|
||||
"""
|
||||
The transformation looks for the pattern with Sigmoid defining the Swish function: Swish(x) = x * Sigmoid(x)
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('sigmoid', dict(op='Sigmoid')),
|
||||
('mul', dict(op='Mul')),
|
||||
],
|
||||
edges=[
|
||||
('sigmoid', 'mul', {}),
|
||||
])
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
|
||||
sigmoid = match['sigmoid']
|
||||
mul = match['mul']
|
||||
mul_name = mul.soft_get('name', mul.id)
|
||||
|
||||
# determine the input port of Mul which gets the 'input' node output
|
||||
mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Sigmoid')
|
||||
|
||||
# check that the same tensor provided as input to Mul and Sigmoid
|
||||
if mul.in_port(mul_input_port_idx).get_source() != sigmoid.in_port(0).get_source():
|
||||
return
|
||||
|
||||
swish = Swish(graph, {}).create_node()
|
||||
swish.in_port(0).connect(sigmoid.in_port(0).get_source())
|
||||
mul.out_port(0).get_connection().set_source(swish.out_port(0))
|
||||
|
||||
rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)])
|
||||
|
||||
|
||||
class SwishWithSigmoidWithBeta(FrontReplacementSubgraph):
|
||||
"""
|
||||
The transformation looks for the pattern with Sigmoid defining the Swish function: Swish(x) = x * Sigmoid(x * beta)
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('sigmoid', dict(op='Sigmoid')),
|
||||
('beta', dict()),
|
||||
('mul_beta', dict(op='Mul')),
|
||||
('mul', dict(op='Mul')),
|
||||
],
|
||||
edges=[
|
||||
('beta', 'mul_beta', {}),
|
||||
('mul_beta', 'sigmoid', {}),
|
||||
('sigmoid', 'mul', {}),
|
||||
])
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
|
||||
beta = match['beta']
|
||||
mul = match['mul']
|
||||
mul_beta = match['mul_beta']
|
||||
mul_name = mul.soft_get('name', mul.id)
|
||||
|
||||
# determine the input port of Muls which get the 'input' node output
|
||||
mul_beta_input_port_idx = int(mul_beta.in_port(0).get_connection().get_source().node.id == beta.id)
|
||||
mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Sigmoid')
|
||||
|
||||
# check that the same tensor provided as input to Mul and MulBeta
|
||||
if mul.in_port(mul_input_port_idx).get_source() != mul_beta.in_port(mul_beta_input_port_idx).get_source():
|
||||
return
|
||||
|
||||
swish = Swish(graph, {}).create_node()
|
||||
swish.in_port(0).connect(mul_beta.in_port(mul_beta_input_port_idx).get_source())
|
||||
|
||||
# connect Beta value
|
||||
swish.in_port(1).connect(mul_beta.in_port(1 - mul_beta_input_port_idx).get_source())
|
||||
|
||||
mul.out_port(0).get_connection().set_source(swish.out_port(0))
|
||||
|
||||
rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)])
|
@ -1,87 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.fakequantize import FakeQuantize
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.graph.graph import Graph, rename_nodes
|
||||
from mo.ops.const import Const
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
class QuantizeDequantizeLinear(FrontReplacementSubgraph):
|
||||
"""
|
||||
Fuses QuantizeLinear and DequantizeLinear nodes into single FakeQuantize.
|
||||
Covers cases when the values for zero point and scale are same in both QuantizeLinear and DequantizeLinear.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('quantize', dict(op='QuantizeLinear')),
|
||||
('dequantize', dict(op='DequantizeLinear')),
|
||||
],
|
||||
edges=[
|
||||
('quantize', 'dequantize', {'in': 0}),
|
||||
]
|
||||
)
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
|
||||
|
||||
q = match['quantize']
|
||||
dq = match['dequantize']
|
||||
|
||||
q_scale = q.in_port(1).get_source().node
|
||||
q_zerop = q.in_port(2).get_source().node
|
||||
dq_scale = dq.in_port(1).get_source().node
|
||||
dq_zerop = dq.in_port(2).get_source().node
|
||||
|
||||
inp_port = q.in_port(0).get_source()
|
||||
name = inp_port.node.soft_get('name', inp_port.node.id)
|
||||
|
||||
# only constant as for zero_point/scale supported
|
||||
if q_scale.soft_get('type') == 'Const' and dq_scale.soft_get('type') == 'Const' and \
|
||||
q_zerop.soft_get('type') == 'Const' and dq_zerop.soft_get('type') == 'Const':
|
||||
|
||||
# only patterns with same scale/zero_point values for Q and DQ are supported
|
||||
if q_scale.value == dq_scale.value and q_zerop.value == dq_zerop.value:
|
||||
log.debug('Found Q-DQ pattern after {}'.format(name))
|
||||
|
||||
zero_point_type = q_zerop.value.dtype
|
||||
# data type affects range of output values: [-128..127] or [0..255]
|
||||
if zero_point_type == np.int8:
|
||||
output_min_value = -128.0
|
||||
output_max_value = 127.0
|
||||
elif zero_point_type == np.uint8:
|
||||
output_min_value = 0.0
|
||||
output_max_value = 255.0
|
||||
else:
|
||||
raise Error('Not supported type {} for zero point value in node {}'.format(
|
||||
zero_point_type, q_zerop.soft_get('name')))
|
||||
min_value = q_scale.value * (output_min_value - q_zerop.value)
|
||||
max_value = q_scale.value * (output_max_value - q_zerop.value)
|
||||
input_min = Const(graph, {'value': np.array(min_value)}).create_node()
|
||||
input_max = Const(graph, {'value': np.array(max_value)}).create_node()
|
||||
|
||||
FQ = FakeQuantize(graph, {
|
||||
'levels': 256,
|
||||
'name': match['quantize'].name + '_Dequantize/FakeQuantize'
|
||||
}).create_node()
|
||||
|
||||
FQ.in_port(0).connect(match['quantize'].in_port(0).get_source())
|
||||
FQ.in_port(1).connect(input_min.out_port(0))
|
||||
FQ.in_port(2).connect(input_max.out_port(0))
|
||||
FQ.in_port(3).connect(input_min.out_port(0))
|
||||
FQ.in_port(4).connect(input_max.out_port(0))
|
||||
|
||||
match['dequantize'].out_port(0).get_connection().set_source(FQ.out_port(0))
|
||||
dq_name = match['dequantize'].soft_get('name', match['dequantize'].id)
|
||||
rename_nodes([(match['dequantize'], dq_name + '/to_be_removed'), (FQ, dq_name)])
|
||||
else:
|
||||
raise Error('QuantizeLinear and DequantizeLinear (after {}) have different scale or zero-point values, '
|
||||
'cannot fuse into FakeQuantize!'.format(name))
|
@ -1,117 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.front.split_normalizer import AttributedVariadicSplitToVariadicSplit
|
||||
from extensions.ops.range import Range
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph
|
||||
from mo.graph.graph import Node
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.squeeze import Squeeze
|
||||
from mo.utils.shape import node_to_get_batch_value
|
||||
|
||||
|
||||
def skip_nodes_by_condition(current_node: Node, condition: callable):
|
||||
while condition(current_node):
|
||||
current_node = current_node.in_node()
|
||||
return current_node
|
||||
|
||||
|
||||
class RemoveFilteringBoxesBySize(FrontReplacementSubgraph):
|
||||
"""
|
||||
The transformation looks for a sub-graph that selects boxes with nonzero height and width. The output node of this
|
||||
sub-graph is a Cast node that produces indices of nodes to be preserved. The transformation creates a new sub-graph
|
||||
that produces a tensor with values from 0 to input.shape[0] to select all boxes. The output of this sub-graph will
|
||||
be used in the NonMaxSuppression so the implementation of this layer should ignore boxes with negative sizes.
|
||||
"""
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [AttributedVariadicSplitToVariadicSplit]
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('split', dict(op='VariadicSplit')),
|
||||
('sub_1', dict(op='Sub')),
|
||||
('sub_2', dict(op='Sub')),
|
||||
('add_1', dict(op='Add')),
|
||||
('add_2', dict(op='Add')),
|
||||
('concat', dict(op='Concat')),
|
||||
('split_2', dict(op='VariadicSplit')),
|
||||
('squeeze_1', dict(op='Squeeze')),
|
||||
('squeeze_2', dict(op='Squeeze')),
|
||||
('less_1', dict(op='Less')),
|
||||
('less_2', dict(op='Less')),
|
||||
('not_1', dict(op='LogicalNot')),
|
||||
('not_2', dict(op='LogicalNot')),
|
||||
('cast_11', dict(op='Cast')),
|
||||
('cast_12', dict(op='Cast')),
|
||||
('cast_21', dict(op='Cast')),
|
||||
('cast_22', dict(op='Cast')),
|
||||
('and', dict(op='LogicalAnd')),
|
||||
('cast_31', dict(op='Cast')),
|
||||
('cast_32', dict(op='Cast')),
|
||||
('nonzero', dict(op='NonZero')),
|
||||
('transpose', dict(op='Transpose')),
|
||||
('squeeze', dict(op='Squeeze')),
|
||||
('cast', dict(op='Cast')),
|
||||
],
|
||||
edges=[
|
||||
('split', 'sub_1', {'in': 0, 'out': 2}),
|
||||
('split', 'sub_1', {'in': 1, 'out': 0}),
|
||||
('split', 'sub_2', {'in': 0, 'out': 3}),
|
||||
('split', 'sub_2', {'in': 1, 'out': 1}),
|
||||
('sub_1', 'add_1', {}),
|
||||
('sub_2', 'add_2', {}),
|
||||
('split', 'concat', {'in': 0, 'out': 0}),
|
||||
('split', 'concat', {'in': 1, 'out': 1}),
|
||||
('add_1', 'concat', {'in': 2, 'out': 0}),
|
||||
('add_2', 'concat', {'in': 3, 'out': 0}),
|
||||
('concat', 'split_2', {}),
|
||||
('split_2', 'squeeze_1', {'in': 0, 'out': 2}),
|
||||
('split_2', 'squeeze_2', {'in': 0, 'out': 3}),
|
||||
('squeeze_1', 'less_1', {}),
|
||||
('squeeze_2', 'less_2', {}),
|
||||
('less_1', 'not_1', {}),
|
||||
('less_2', 'not_2', {}),
|
||||
('not_1', 'cast_11', {}),
|
||||
('cast_11', 'cast_12', {}),
|
||||
('not_2', 'cast_21', {}),
|
||||
('cast_21', 'cast_22', {}),
|
||||
('cast_12', 'and', {}),
|
||||
('cast_22', 'and', {}),
|
||||
('and', 'cast_31', {}),
|
||||
('cast_31', 'cast_32', {}),
|
||||
('cast_32', 'nonzero', {}),
|
||||
('nonzero', 'transpose', {}),
|
||||
('transpose', 'squeeze', {}),
|
||||
('squeeze', 'cast', {}),
|
||||
])
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
source_connection = match['split'].in_port(0).get_connection()
|
||||
source_node = source_connection.get_source().node
|
||||
cast_node = match['cast']
|
||||
|
||||
range_node = Range(graph, {'name': source_node.id + '/Range'}).create_node()
|
||||
start_node = Const(graph, {'name': range_node.id + '/Start', 'value': int64_array(0)}).create_node()
|
||||
|
||||
step_node = Const(graph, {'name': range_node.id + '/Step', 'value': int64_array(1)}).create_node()
|
||||
input_shape_node = Shape(graph, {'name': start_node.id + '/Shape'}).create_node()
|
||||
input_shape_node.in_port(0).connect(source_node.out_port(0))
|
||||
|
||||
limit_node_1D = node_to_get_batch_value(input_shape_node)
|
||||
limit_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
|
||||
{'name': source_node.id + '/batch_0D_value'}, limit_node_1D)
|
||||
|
||||
range_node.in_port(0).connect(start_node.out_port(0))
|
||||
range_node.in_port(1).connect(limit_node.out_port(0))
|
||||
range_node.in_port(2).connect(step_node.out_port(0))
|
||||
cast_node.out_port(0).get_connection().set_source(range_node.out_port(0))
|
||||
|
||||
graph.remove_nodes_from([node.id for node in match.values()])
|
@ -1,66 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.front.Mish_fusion import MishFusion
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, regular_op, result, build_graph_with_edge_attrs
|
||||
|
||||
ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('mish', {'type': 'Mish', 'name': 'final_mul'}),
|
||||
**result('result')
|
||||
}
|
||||
ref_edges = [('input', 'mish'), ('mish', 'result')]
|
||||
|
||||
|
||||
class MishFusionTest(unittest.TestCase):
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('softplus', {'op': 'SoftPlus'}),
|
||||
**regular_op('tanh', {'op': 'Tanh'}),
|
||||
**regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
|
||||
**result('result'),
|
||||
}
|
||||
|
||||
edges = [('input', 'softplus', {'in': 0, 'out': 0}),
|
||||
('input', 'mul', {'in': 0, 'out': 0}),
|
||||
('softplus', 'tanh', {'in': 0, 'out': 0}),
|
||||
('tanh', 'mul', {'in': 1, 'out': 0}),
|
||||
('mul', 'result', {'in': 0, 'out': 0})]
|
||||
|
||||
def test_mish_fusion(self):
|
||||
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
|
||||
|
||||
graph_ref = build_graph(ref_nodes, ref_edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
MishFusion().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
|
||||
graph.get_op_nodes(name='final_mul')[0].op == 'Mish')
|
||||
|
||||
def test_mish_fusion_different_source(self):
|
||||
# check case when different tensors goes to Mul and SoftPlus
|
||||
graph = build_graph_with_edge_attrs({
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('input_2', {'type': 'Parameter'}),
|
||||
**regular_op('softplus', {'op': 'SoftPlus'}),
|
||||
**regular_op('tanh', {'op': 'Tanh'}),
|
||||
**regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
|
||||
**result('result'),
|
||||
}, [('input', 'softplus', {'in': 0, 'out': 0}),
|
||||
('input_2', 'mul', {'in': 0, 'out': 0}),
|
||||
('softplus', 'tanh', {'in': 0, 'out': 0}),
|
||||
('tanh', 'mul', {'in': 1, 'out': 0}),
|
||||
('mul', 'result', {'in': 0, 'out': 0})], {})
|
||||
|
||||
graph_ref = graph.copy()
|
||||
graph.stage = 'front'
|
||||
|
||||
MishFusion().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
@ -1,57 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.front.Softplus_fusion import SoftplusFusion
|
||||
from mo.front.common.partial_infer.utils import float_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs
|
||||
|
||||
ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('softplus', {'type': 'SoftPlus', 'name': 'final_log'}),
|
||||
**result('result')
|
||||
}
|
||||
ref_edges = [('input', 'softplus'), ('softplus', 'result')]
|
||||
|
||||
|
||||
class SoftplusFusionTest(unittest.TestCase):
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('exp', {'op': 'Exp'}),
|
||||
**const('const_1', float_array([1.0])),
|
||||
**regular_op('add', {'op': 'Add'}),
|
||||
**regular_op('ln', {'op': 'Log', 'name': 'final_log'}),
|
||||
**result('result'),
|
||||
}
|
||||
|
||||
edges = [('input', 'exp', {'in': 0, 'out': 0}),
|
||||
('const_1', 'add', {'in': 0, 'out': 0}),
|
||||
('exp', 'add', {'in': 1, 'out': 0}),
|
||||
('add', 'ln', {'in': 0, 'out': 0}),
|
||||
('ln', 'result', {'in': 0, 'out': 0})]
|
||||
|
||||
def test_softplus_fusion_test(self):
|
||||
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
|
||||
|
||||
graph_ref = build_graph(ref_nodes, ref_edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
SoftplusFusion().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_log')) == 1 and
|
||||
graph.get_op_nodes(name='final_log')[0].op == 'SoftPlus')
|
||||
|
||||
def test_softplus_fusion_test_wrong_const(self):
|
||||
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_1': {'value': float_array([0.9999])}})
|
||||
|
||||
graph_ref = graph.copy()
|
||||
graph.stage = 'front'
|
||||
|
||||
SoftplusFusion().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
|
@ -1,119 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.front.Swish_fusion import SwishWithSigmoidWithoutBeta, SwishWithSigmoidWithBeta
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, regular_op, result, build_graph_with_edge_attrs
|
||||
|
||||
ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('swish', {'type': 'Swish', 'name': 'final_mul'}),
|
||||
**result('result')
|
||||
}
|
||||
ref_edges = [('input', 'swish'), ('swish', 'result')]
|
||||
|
||||
|
||||
class SwishWithSigmoidWithoutBetaTest(unittest.TestCase):
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('sigmoid', {'op': 'Sigmoid'}),
|
||||
**regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
|
||||
**result('result'),
|
||||
}
|
||||
|
||||
edges = [('input', 'mul', {'in': 0, 'out': 0}),
|
||||
('input', 'sigmoid', {'in': 0, 'out': 0}),
|
||||
('sigmoid', 'mul', {'in': 1, 'out': 0}),
|
||||
('mul', 'result', {'in': 0, 'out': 0})]
|
||||
|
||||
def test_swish_with_sigmoid_without_beta_test(self):
|
||||
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
|
||||
|
||||
graph_ref = build_graph(ref_nodes, ref_edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
SwishWithSigmoidWithoutBeta().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
|
||||
graph.get_op_nodes(name='final_mul')[0].op == 'Swish')
|
||||
|
||||
def test_swish_with_sigmoid_without_beta_different_tensors(self):
|
||||
graph = build_graph_with_edge_attrs({
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('input_2', {'type': 'Parameter'}),
|
||||
**regular_op('sigmoid', {'op': 'Sigmoid'}),
|
||||
**regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
|
||||
**result('result'),
|
||||
}, [('input_2', 'mul', {'in': 0, 'out': 0}),
|
||||
('input', 'sigmoid', {'in': 0, 'out': 0}),
|
||||
('sigmoid', 'mul', {'in': 1, 'out': 0}),
|
||||
('mul', 'result', {'in': 0, 'out': 0})], {})
|
||||
|
||||
graph_ref = graph.copy()
|
||||
graph.stage = 'front'
|
||||
|
||||
SwishWithSigmoidWithoutBeta().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
class SwishWithSigmoidWithBetaTest(unittest.TestCase):
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('beta', {'type': 'Parameter'}),
|
||||
**regular_op('mul_beta', {'op': 'Mul'}),
|
||||
**regular_op('sigmoid', {'op': 'Sigmoid'}),
|
||||
**regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
|
||||
**result('result'),
|
||||
}
|
||||
|
||||
edges = [('input', 'mul_beta', {'in': 0, 'out': 0}),
|
||||
('input', 'mul_2', {'in': 0, 'out': 0}),
|
||||
('beta', 'mul_beta', {'in': 1, 'out': 0}),
|
||||
('mul_beta', 'sigmoid', {'in': 0, 'out': 0}),
|
||||
('sigmoid', 'mul_2', {'in': 1, 'out': 0}),
|
||||
('mul_2', 'result', {'in': 0, 'out': 0})]
|
||||
|
||||
def test_swish_with_sigmoid_with_beta_test(self):
|
||||
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
|
||||
|
||||
new_ref_nodes = ref_nodes.copy()
|
||||
new_ref_nodes.update(**regular_op('beta', {'type': 'Parameter'}))
|
||||
|
||||
graph_ref = build_graph(new_ref_nodes, ref_edges + [('beta', 'swish')])
|
||||
graph.stage = 'front'
|
||||
|
||||
SwishWithSigmoidWithBeta().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
|
||||
graph.get_op_nodes(name='final_mul')[0].op == 'Swish')
|
||||
|
||||
def test_swish_with_sigmoid_with_beta_different_tensors(self):
|
||||
graph = build_graph_with_edge_attrs({
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**regular_op('input_2', {'type': 'Parameter'}),
|
||||
**regular_op('beta', {'type': 'Parameter'}),
|
||||
**regular_op('mul_beta', {'op': 'Mul'}),
|
||||
**regular_op('sigmoid', {'op': 'Sigmoid'}),
|
||||
**regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
|
||||
**result('result'),
|
||||
}, [('input', 'mul_beta', {'in': 0, 'out': 0}),
|
||||
('input_2', 'mul_2', {'in': 0, 'out': 0}),
|
||||
('beta', 'mul_beta', {'in': 1, 'out': 0}),
|
||||
('mul_beta', 'sigmoid', {'in': 0, 'out': 0}),
|
||||
('sigmoid', 'mul_2', {'in': 1, 'out': 0}),
|
||||
('mul_2', 'result', {'in': 0, 'out': 0})], {})
|
||||
|
||||
graph_ref = graph.copy()
|
||||
graph.stage = 'front'
|
||||
|
||||
SwishWithSigmoidWithBeta().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
@ -1,115 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.onnx.quantize_dequantize_linear import QuantizeDequantizeLinear
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
# quantize and dequantize share tensors with scale/zp
|
||||
nodes0_attributes = {
|
||||
'input': {'kind': 'op', 'op': 'AnyOp'},
|
||||
'quantize': {'kind': 'op', 'op': 'QuantizeLinear'},
|
||||
'dequantize': {'kind': 'op', 'op': 'DequantizeLinear'},
|
||||
'scale_param': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
|
||||
'zerop_param': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
|
||||
'out': {'kind': 'op', 'op': 'AnyOp'},
|
||||
}
|
||||
|
||||
# quantize and dequantize do not share tensors with scale/zp
|
||||
nodes1_attributes = {
|
||||
'input': {'kind': 'op', 'op': 'AnyOp'},
|
||||
'quantize': {'kind': 'op', 'op': 'QuantizeLinear'},
|
||||
'dequantize': {'kind': 'op', 'op': 'DequantizeLinear'},
|
||||
'scale_param_q': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
|
||||
'zerop_param_q': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
|
||||
'scale_param_dq': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
|
||||
'zerop_param_dq': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
|
||||
'out': {'kind': 'op', 'op': 'AnyOp'},
|
||||
}
|
||||
|
||||
nodes_ref_attributes = {
|
||||
'input': {'kind': 'op', 'op': 'AnyOp'},
|
||||
'fq': {'kind': 'op', 'op': 'FakeQuantize'},
|
||||
'min_param': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
|
||||
'max_param': {'kind': 'op', 'type': 'Const', 'op': 'Const'},
|
||||
'out': {'kind': 'op', 'op': 'AnyOp'},
|
||||
}
|
||||
|
||||
|
||||
class TestQuantizeDeQuantize2FakeQuantize(unittest.TestCase):
|
||||
|
||||
def test_quantizedequantize2fakequantize_0(self):
|
||||
# testing the code path with uint8 zero-point
|
||||
graph = build_graph(nodes1_attributes,
|
||||
[('input', 'quantize'),
|
||||
('quantize', 'dequantize'),
|
||||
('scale_param_q', 'quantize'),
|
||||
('zerop_param_q', 'quantize'),
|
||||
('scale_param_dq', 'dequantize'),
|
||||
('zerop_param_dq', 'dequantize'),
|
||||
('dequantize', 'out'),
|
||||
],
|
||||
{'scale_param_q': {'shape': np.array([1]), 'value': np.float32(1.0 / 255)},
|
||||
'zerop_param_q': {'shape': np.array([1]), 'value': np.uint8(0)},
|
||||
'scale_param_dq': {'shape': np.array([1]), 'value': np.float32(1.0 / 255)},
|
||||
'zerop_param_dq': {'shape': np.array([1]), 'value': np.uint8(0)},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_ref_attributes,
|
||||
[('input', 'fq', {'in': 0}),
|
||||
('min_param', 'fq', {'out': 0, 'in': 1}),
|
||||
('min_param', 'fq', {'out': 0, 'in': 3}),
|
||||
('max_param', 'fq', {'out': 0, 'in': 2}),
|
||||
('max_param', 'fq', {'out': 0, 'in': 4}),
|
||||
('fq', 'out'),
|
||||
],
|
||||
{'fq': {'levels': 256},
|
||||
'min_param': {'value': np.float32(0.0)},
|
||||
'max_param': {'value': np.float32(1.0)},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph.stage = 'front'
|
||||
tested_class = QuantizeDequantizeLinear()
|
||||
tested_class.find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'out', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_quantizedequantize2fakequantize_1(self):
|
||||
# testing the code path with int8 zero-point
|
||||
graph = build_graph(nodes0_attributes,
|
||||
[('input', 'quantize'),
|
||||
('quantize', 'dequantize'),
|
||||
('scale_param', 'quantize'),
|
||||
('zerop_param', 'quantize'),
|
||||
('scale_param', 'dequantize'),
|
||||
('zerop_param', 'dequantize'),
|
||||
('dequantize', 'out'),
|
||||
],
|
||||
{'scale_param': {'shape': np.array([1]), 'value': np.float32(1.0 / 255)},
|
||||
'zerop_param': {'shape': np.array([1]), 'value': np.int8(0)},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_ref_attributes,
|
||||
[('input', 'fq', {'in': 0}),
|
||||
('min_param', 'fq', {'out': 0, 'in': 1}),
|
||||
('min_param', 'fq', {'out': 0, 'in': 3}),
|
||||
('max_param', 'fq', {'out': 0, 'in': 2}),
|
||||
('max_param', 'fq', {'out': 0, 'in': 4}),
|
||||
('fq', 'out'),
|
||||
],
|
||||
{'fq': {'levels': 256},
|
||||
'min_param': {'value': np.float32(-128.0 / 255)},
|
||||
'max_param': {'value': np.float32(127.0 / 255)},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph.stage = 'front'
|
||||
tested_class = QuantizeDequantizeLinear()
|
||||
tested_class.find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'out', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
@ -45,6 +45,7 @@ namespace ngraph
|
||||
|
||||
const std::unordered_set<std::string>& get_names() const;
|
||||
void set_names(const std::unordered_set<std::string>& names);
|
||||
void add_names(const std::unordered_set<std::string>& names);
|
||||
void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
|
||||
void set_element_type(const element::Type& elemenet_type);
|
||||
void set_partial_shape(const PartialShape& partial_shape);
|
||||
|
@ -123,6 +123,14 @@ void descriptor::Tensor::set_names(const std::unordered_set<std::string>& names)
|
||||
m_names = names;
|
||||
}
|
||||
|
||||
void descriptor::Tensor::add_names(const std::unordered_set<std::string>& names)
|
||||
{
|
||||
for (const auto& name : names)
|
||||
{
|
||||
m_names.insert(name);
|
||||
}
|
||||
}
|
||||
|
||||
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
|
||||
{
|
||||
std::string names;
|
||||
|
@ -167,11 +167,9 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
|
||||
// Change I's connected upstream output to O_rep
|
||||
for (size_t i = 0; i < target->get_output_size(); i++)
|
||||
{
|
||||
for (auto& input : target->output(i).get_target_inputs())
|
||||
{
|
||||
input.replace_source_output(replacement->output(output_order[i]));
|
||||
}
|
||||
target->output(i).replace(replacement->output(output_order[i]));
|
||||
}
|
||||
|
||||
replacement->add_node_control_dependents(target);
|
||||
replacement->add_node_control_dependencies(target);
|
||||
target->clear_control_dependents();
|
||||
@ -912,7 +910,15 @@ bool ngraph::replace_output_update_name(Output<Node> output, const Output<Node>&
|
||||
replacement.get_tensor().set_name(output.get_node()->get_friendly_name());
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
// Save replacement tensor names before replacement as they will be
|
||||
// overrided by the output tensor names
|
||||
auto output_names = replacement.get_tensor_ptr()->get_names();
|
||||
output.replace(replacement);
|
||||
|
||||
// Restore back original replacement tensor names
|
||||
replacement.get_tensor().add_names(output_names);
|
||||
|
||||
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
|
||||
replacement.get_node_shared_ptr());
|
||||
return true;
|
||||
|
@ -76,6 +76,7 @@ namespace ngraph
|
||||
{
|
||||
input.replace_source_output(replacement);
|
||||
}
|
||||
replacement.get_tensor_ptr()->set_names(get_tensor_ptr()->get_names());
|
||||
}
|
||||
|
||||
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
|
||||
|
@ -57,11 +57,13 @@ void op::util::ScatterNDBase::validate_and_infer_types()
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
inputs_rank.is_dynamic() || indices_rank.is_dynamic() ||
|
||||
indices_shape[indices_rank.get_length() - 1].is_dynamic() ||
|
||||
indices_shape[indices_rank.get_length() - 1].get_length() <=
|
||||
inputs_rank.get_length(),
|
||||
"Last dimension of indices can be at most the rank of inputs");
|
||||
|
||||
if (inputs_rank.is_static() && indices_rank.is_static() && updates_rank.is_static())
|
||||
if (inputs_rank.is_static() && indices_rank.is_static() && updates_rank.is_static() &&
|
||||
indices_shape[indices_rank.get_length() - 1].is_static())
|
||||
{
|
||||
auto expected_updates_rank = indices_rank.get_length() + inputs_rank.get_length() -
|
||||
indices_shape[indices_rank.get_length() - 1].get_length() - 1;
|
||||
|
@ -108,3 +108,63 @@ TEST(replace_node, replace_nodes)
|
||||
ASSERT_EQ(z_replacement->get_input_node_shared_ptr(0), x_replacement);
|
||||
ASSERT_EQ(z_replacement->get_input_node_shared_ptr(1), mul);
|
||||
}
|
||||
|
||||
TEST(replace_node, simple_node_replacement)
|
||||
{
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
|
||||
param->output(0).get_tensor().set_names({"a", "b"});
|
||||
auto relu = std::make_shared<op::Relu>(param);
|
||||
relu->output(0).get_tensor().set_names({"c", "d"});
|
||||
|
||||
auto new_relu = std::make_shared<op::Relu>(param);
|
||||
new_relu->output(0).get_tensor().set_names({"f"});
|
||||
replace_node(relu, new_relu);
|
||||
|
||||
ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d"}));
|
||||
}
|
||||
|
||||
TEST(replace_node, node_elimination)
|
||||
{
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
|
||||
param->output(0).get_tensor().set_names({"a", "b"});
|
||||
auto relu1 = std::make_shared<op::Relu>(param);
|
||||
relu1->output(0).get_tensor().set_names({"c", "d"});
|
||||
auto relu2 = std::make_shared<op::Relu>(relu1);
|
||||
relu2->output(0).get_tensor().set_names({"e", "f"});
|
||||
|
||||
ASSERT_TRUE(replace_output_update_name(relu2->output(0), relu2->input_value(0)));
|
||||
ASSERT_EQ(relu1->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d", "e", "f"}));
|
||||
ASSERT_EQ(param->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"a", "b"}));
|
||||
}
|
||||
|
||||
TEST(replace_node, output_replacement)
|
||||
{
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
|
||||
param->output(0).get_tensor().set_names({"a", "b"});
|
||||
auto relu = std::make_shared<op::Relu>(param);
|
||||
relu->output(0).get_tensor().set_names({"c", "d"});
|
||||
|
||||
auto new_relu = std::make_shared<op::Relu>(param);
|
||||
new_relu->output(0).get_tensor().set_names({"f"});
|
||||
|
||||
relu->output(0).replace(new_relu->output(0));
|
||||
|
||||
ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d"}));
|
||||
}
|
||||
|
||||
TEST(replace_node, source_replacement)
|
||||
{
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
|
||||
param->output(0).get_tensor().set_names({"a", "b"});
|
||||
|
||||
auto param1 = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
|
||||
param1->output(0).get_tensor().set_names({"c", "d"});
|
||||
|
||||
auto relu = std::make_shared<op::Relu>(param);
|
||||
relu->input(0).replace_source_output(param1->output(0));
|
||||
|
||||
ASSERT_EQ(param->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"a", "b"}));
|
||||
ASSERT_EQ(param1->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d"}));
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user