Add StrideOptimization pass (#5314)

* Add StrideOptimization pass

It's based on model-optimizer/mo/middle/passes/fusing/resnet_optimization.py

* use BackwardGraphRewrite

* code style

* handle other nodes

* address review comments

* rename files

* fix comment

* fix narrowing warning

* add squeeze after pool if its input was reshaped

* address review comments
This commit is contained in:
Mateusz Tabaka
2021-06-28 09:41:36 +02:00
committed by GitHub
parent 9eb80071b0
commit 813531f45d
10 changed files with 786 additions and 3 deletions

View File

@@ -0,0 +1,70 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/util.hpp>
#include <ngraph/pass/pass.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvStridesPropagation;
class TRANSFORMATIONS_API SupportedNodesStridesPropagation;
class TRANSFORMATIONS_API UnsupportedNodesStridesPropagation;
class TRANSFORMATIONS_API StridesOptimization;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief ConvStridesPropagation either propagates stride (greater than 1) from Convolution up through the graph
* or inserts pooling between current node and its consumers if the consumers have different StridesProp attributes.
* Strides can be propagated if Convolution kernel is {1, 1, ...}
*/
class ngraph::pass::ConvStridesPropagation: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvStridesPropagation();
};
/**
* @ingroup ie_transformation_common_api
* @brief SupportedNodesStridesPropagation either propagates stride (greater than 1) from current node up through the graph
* or inserts pooling between current node and its consumers if the consumers have different StridesProp attributes.
*/
class ngraph::pass::SupportedNodesStridesPropagation: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SupportedNodesStridesPropagation();
};
/**
* @ingroup ie_transformation_common_api
* @brief UnsupportedNodesStridesPropagation inserts pooling between current node and its consumers
* if the consumers have different StridesProp attributes.
*/
class ngraph::pass::UnsupportedNodesStridesPropagation: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
UnsupportedNodesStridesPropagation();
};
/**
* @ingroup ie_transformation_common_api
* @brief StridesOptimization transformation works backward on function and propagates strides up through the graph if possible
*/
class ngraph::pass::StridesOptimization: public ngraph::pass::BackwardGraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
StridesOptimization() {
add_matcher<ngraph::pass::ConvStridesPropagation>();
add_matcher<ngraph::pass::SupportedNodesStridesPropagation>();
add_matcher<ngraph::pass::UnsupportedNodesStridesPropagation>();
}
};

View File

@@ -0,0 +1,26 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/variant.hpp>
#include <transformations_visibility.hpp>
namespace ngraph {
template <>
class TRANSFORMATIONS_API VariantWrapper<Strides> : public VariantImpl<Strides> {
public:
static constexpr VariantTypeInfo type_info{"Variant::Strides", 0};
const VariantTypeInfo& get_type_info() const override { return type_info; }
VariantWrapper(const value_type& value)
: VariantImpl<value_type>(value) {
}
};
} // namespace ngraph
TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node);
TRANSFORMATIONS_API ngraph::Strides get_strides_prop(const ngraph::Input<ngraph::Node>& node);
TRANSFORMATIONS_API void insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides);

View File

@@ -128,6 +128,9 @@ Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & inpu
}
return output[0];
}
TRANSFORMATIONS_API std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& node);
} // namespace util
} // namespace op
} // namespace ngraph

View File

@@ -43,6 +43,7 @@
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/split_squeeze_concat_fusion.hpp"
#include "transformations/common_optimizations/transpose_to_reshape.hpp"
#include "transformations/common_optimizations/strides_optimization.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
@@ -185,6 +186,11 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
fq_fusions->add_matcher<ngraph::pass::ReluFakeQuantizeFusion>();
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
// StridesOptimization should be at the very end
// because we cannot insert any MaxPools since they may prevent
// other optimizations
manager.register_pass<ngraph::pass::StridesOptimization>();
manager.run_passes(f);
// Returning value is false because pass::Manager always apply Validation pass

View File

@@ -0,0 +1,175 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <numeric>
#include "itt.hpp"
#include <transformations/common_optimizations/strides_optimization.hpp>
#include <transformations/rt_info/strides_property.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/variant.hpp>
#include <ngraph/validation_util.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::StridesOptimization, "StridesOptimization", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvStridesPropagation, "ConvStridesPropagation", 0);
static bool can_propagate_conv_stride(const std::shared_ptr<ngraph::Node>& conv) {
const auto& kernel_shape = conv->input_value(1).get_shape();
return std::all_of(kernel_shape.begin() + 2, kernel_shape.end(), [] (size_t s) -> bool { return s == 1; });
}
static std::tuple<ngraph::Strides, bool> check_next_ops(const std::vector<ngraph::Input<ngraph::Node>>& next_ops) {
std::vector<ngraph::Strides> strides;
for (const auto& op : next_ops) {
if (!has_strides_prop(op)) {
return std::make_tuple(ngraph::Strides{}, false);
}
strides.push_back(get_strides_prop(op));
}
bool all_ops_are_valid = std::all_of(strides.begin(), strides.end(),
[&strides] (const ngraph::Strides& s) -> bool {
bool all_ones = std::all_of(s.begin(), s.end(), [] (size_t i) -> bool { return i == 1; });
return s == strides[0] && !all_ones;
});
return std::make_tuple(strides[0], all_ops_are_valid);
}
static void insert_pooling(const ngraph::Output<ngraph::Node>& first, ngraph::Input<ngraph::Node>& second, const ngraph::Strides& strides) {
auto first_node = first.get_node_shared_ptr();
auto rank = first.get_partial_shape().rank();
bool do_reshape = rank.is_static() && static_cast<size_t>(rank.get_length()) < strides.size() + 2;
if (do_reshape) {
size_t diff = strides.size() + 2 - static_cast<size_t>(rank.get_length());
auto ones = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{diff}, std::vector<int64_t>(diff, 1));
auto current_shape = std::make_shared<ngraph::opset7::ShapeOf>(first);
std::shared_ptr<ngraph::Node> new_shape = std::make_shared<ngraph::opset7::Concat>(ngraph::OutputVector{ones, current_shape}, 0);
std::shared_ptr<ngraph::Node> constant_new_shape = get_constant_from_source(new_shape);
if (constant_new_shape)
new_shape = constant_new_shape;
first_node = std::make_shared<ngraph::opset7::Reshape>(first_node, new_shape, false);
}
std::shared_ptr<ngraph::Node> new_node = std::make_shared<ngraph::opset7::MaxPool>(first_node, strides, ngraph::Shape{},
ngraph::Shape{}, ngraph::Shape(strides.size(), 1));
if (do_reshape) {
// squeeze dimensions back
size_t diff = strides.size() + 2 - static_cast<size_t>(rank.get_length());
std::vector<size_t> axes(diff);
std::iota(axes.begin(), axes.end(), 0);
new_node = std::make_shared<ngraph::opset7::Squeeze>(new_node,
ngraph::opset7::Constant::create(ngraph::element::u64, ngraph::Shape{diff}, axes));
}
std::shared_ptr<ngraph::Node> constant_new_node = get_constant_from_source(new_node);
if (constant_new_node)
new_node = constant_new_node;
second.replace_source_output(new_node);
}
static void handle_not_equal_stride_props(std::vector<ngraph::Input<ngraph::Node>>&& next_ops) {
for (auto& op : next_ops) {
if (!has_strides_prop(op))
continue;
auto strides = get_strides_prop(op);
bool are_strides_ones = std::all_of(strides.begin(), strides.end(),
[] (size_t s) -> bool { return s == 1; });
if (!are_strides_ones) {
auto conv = dynamic_cast<ngraph::opset7::Convolution*>(op.get_node());
if (conv) {
conv->set_strides(strides);
} else {
insert_pooling(op.get_source_output(), op, strides);
}
}
}
}
ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() {
MATCHER_SCOPE(ConvStridesPropagation);
auto data = pattern::any_input();
auto weights = pattern::any_input(pattern::has_static_shape());
auto conv_pattern = pattern::wrap_type<opset7::Convolution>({data, weights});
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto conv = std::dynamic_pointer_cast<opset7::Convolution>(m.get_match_root());
if (!conv)
return false;
auto conv_strides = conv->get_strides();
Strides strides_ones(conv_strides.size(), 1);
auto next_ops = op::util::get_node_target_inputs(conv);
bool all_ops_are_valid;
Strides strides;
std::tie(strides, all_ops_are_valid) = check_next_ops(next_ops);
if (!all_ops_are_valid) {
handle_not_equal_stride_props(std::move(next_ops));
} else {
std::transform(conv_strides.begin(), conv_strides.end(), strides.begin(), conv_strides.begin(),
[] (size_t s1, size_t s2) -> size_t { return s1 * s2; });
}
if (can_propagate_conv_stride(conv)) {
conv->set_strides(strides_ones);
auto conv_input = conv->input(0);
insert_strides_prop(conv_input, conv_strides);
} else {
conv->set_strides(conv_strides);
}
return true;
};
auto m = std::make_shared<pattern::Matcher>(conv_pattern, matcher_name);
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::SupportedNodesStridesPropagation, "SupportedNodesStridesPropagation", 0);
ngraph::pass::SupportedNodesStridesPropagation::SupportedNodesStridesPropagation() {
MATCHER_SCOPE(SupportedNodesStridesPropagation);
auto root = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, op::util::BinaryElementwiseArithmetic>();
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto node = m.get_match_root();
auto next_ops = op::util::get_node_target_inputs(node);
bool all_ops_are_valid;
Strides strides;
std::tie(strides, all_ops_are_valid) = check_next_ops(next_ops);
if (!all_ops_are_valid) {
return false;
}
for (auto& input : node->inputs()) {
insert_strides_prop(input, strides);
}
return true;
};
auto m = std::make_shared<pattern::Matcher>(root, matcher_name);
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::UnsupportedNodesStridesPropagation, "UnsupportedNodesStridesPropagation", 0);
ngraph::pass::UnsupportedNodesStridesPropagation::UnsupportedNodesStridesPropagation() {
MATCHER_SCOPE(UnsupportedNodesStridesPropagation);
auto root = pattern::any_input();
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto node = m.get_match_root();
auto next_ops = op::util::get_node_target_inputs(node);
handle_not_equal_stride_props(std::move(next_ops));
return true;
};
auto m = std::make_shared<pattern::Matcher>(root, matcher_name);
this->register_matcher(m, callback);
}

View File

@@ -0,0 +1,24 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/rt_info/strides_property.hpp"
constexpr ngraph::VariantTypeInfo ngraph::VariantWrapper<ngraph::Strides>::type_info;
bool has_strides_prop(const ngraph::Input<ngraph::Node>& node) {
const auto& rt_map = node.get_rt_info();
auto it = rt_map.find(ngraph::VariantWrapper<ngraph::Strides>::type_info.name);
return it != rt_map.end();
}
ngraph::Strides get_strides_prop(const ngraph::Input<ngraph::Node>& node) {
const auto& rt_map = node.get_rt_info();
const auto& var = rt_map.at(ngraph::VariantWrapper<ngraph::Strides>::type_info.name);
return ngraph::as_type_ptr<ngraph::VariantWrapper<ngraph::Strides>>(var)->get();
}
void insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) {
auto& rt_map = node.get_rt_info();
rt_map[ngraph::VariantWrapper<ngraph::Strides>::type_info.name] = std::make_shared<ngraph::VariantWrapper<ngraph::Strides>>(strides);
}

View File

@@ -142,6 +142,16 @@ std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<Node>& node, const Ou
return try_fold_unary_output(unary_output_node);
}
std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& node) {
std::vector<Input<Node>> result;
for (auto output : node->outputs()) {
for (auto input : output.get_target_inputs()) {
result.push_back(input);
}
}
return result;
}
} // namespace util
} // namespace op
} // namespace ngraph

View File

@@ -0,0 +1,460 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/common_optimizations/strides_optimization.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
// Tests are based on model-optimizer/mo/middle/passes/fusing/resnet_optimization_test.py
// In description of unit tests below will be used next syntax: Operation(NxM,XxY), where NxM - kernel size, XxY - stride
// Pl->Conv(1x1,1x1)->Conv(1x1,2x2) => Pl->Conv(1x1,2x2)->Conv(1x1,1x1)
TEST(TransformationTests, StridesOptimization1) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
// Pl->Conv(3x3,2x2)->Conv(1x1,2x2) => Pl->Conv(3x3,4x4)->Conv(1x1,1x1)
TEST(TransformationTests, StridesOptimization2) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{4, 4},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
// Pl->Conv(3x3,2x2)->Conv(3x3,2x2) => Same
TEST(TransformationTests, StridesOptimization3) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
// Pl--->Conv(3x3,2x2)->ReLU--->Eltwise-->Conv(1x1,2x2) => Pl--->Conv(3x3,4x4)->ReLU--->Eltwise-->Conv(1x1,1x1)
// `-->Conv(3x3,2x2)->ReLU---` `-->Conv(3x3,4x4)->ReLU---`
TEST(TransformationTests, StridesOptimization4) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto relu_1 = std::make_shared<ngraph::opset7::Relu>(conv_1);
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(data, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto relu_2 = std::make_shared<ngraph::opset7::Relu>(conv_2);
auto add = std::make_shared<ngraph::opset7::Add>(relu_1, relu_2);
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(add, weights_3, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{4, 4},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto relu_1 = std::make_shared<ngraph::opset7::Relu>(conv_1);
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(data, weights_2, ngraph::Strides{4, 4},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto relu_2 = std::make_shared<ngraph::opset7::Relu>(conv_2);
auto add = std::make_shared<ngraph::opset7::Add>(relu_1, relu_2);
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(add, weights_3, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3}, ngraph::ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
// Pl--->Conv(1x1,1x1)->ReLU--->Eltwise-->Conv(1x1,2x2) => Pl--->Conv(1x1,2x2)->ReLU--->Eltwise-->Conv(1x1,1x1)
// `----------------->ReLU---` `-->Pool(1x1,2x2)->ReLU---`
TEST(TransformationTests, StridesOptimization5) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{});
auto relu_1 = std::make_shared<ngraph::opset7::Relu>(conv_1);
auto relu_2 = std::make_shared<ngraph::opset7::Relu>(data);
auto add = std::make_shared<ngraph::opset7::Add>(relu_1, relu_2);
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(add, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto relu_1 = std::make_shared<ngraph::opset7::Relu>(conv_1);
auto pool = std::make_shared<ngraph::opset7::MaxPool>(data, ngraph::Strides{2, 2}, ngraph::Shape{0, 0}, ngraph::Shape{0, 0}, ngraph::Shape{1, 1});
auto relu_2 = std::make_shared<ngraph::opset7::Relu>(pool);
auto add = std::make_shared<ngraph::opset7::Add>(relu_1, relu_2);
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(add, weights_2, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
// Pl->Conv(1x1,1x1)->Conv(1x1,2x2)->Conv(3x3,1x1)->Conv(1x1,2x2)
// =>
// Pl->Conv(1x1,2x2)->Conv(1x1,1x1)->Conv(3x3,2x2)->Conv(1x1,1x1)
TEST(TransformationTests, StridesOptimization6) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(conv_2, weights_3, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_4 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_4 = std::make_shared<ngraph::opset7::Convolution>(conv_3, weights_4, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_4}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 3, 3}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(conv_2, weights_3, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_4 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_4 = std::make_shared<ngraph::opset7::Convolution>(conv_3, weights_4, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_4}, ngraph::ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
// Pl->Conv(1x1,1x1) --> Conv(1x1,2x2) --> Conv(1x1,2x2)
// `--> Relu --> Conv(1x1,2x2)
// =>
// Pl->Conv(1x1,1x1) ---> Conv(1x1,4x4) --> Conv(1x1,1x1)
// `--> Pool(1x1, 2x2) -> Relu --> Conv(1x1,1x1)
TEST(TransformationTests, StridesOptimization7) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(conv_2, weights_3, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto relu = std::make_shared<ngraph::opset7::Relu>(conv_1);
auto weights_4 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_4 = std::make_shared<ngraph::opset7::Convolution>(relu, weights_4, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3, conv_4}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(conv_1, weights_2, ngraph::Strides{4, 4},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(conv_2, weights_3, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto pool = std::make_shared<ngraph::opset7::MaxPool>(conv_1, ngraph::Strides{2, 2}, ngraph::Shape{0, 0}, ngraph::Shape{0, 0}, ngraph::Shape{1, 1});
auto relu = std::make_shared<ngraph::opset7::Relu>(pool);
auto weights_4 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_4 = std::make_shared<ngraph::opset7::Convolution>(relu, weights_4, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3, conv_4}, ngraph::ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
// Pl--->Conv(1x1,1x1)->ReLU--->Eltwise-->Conv(1x1,2x2)-->Eltwise-->Conv(1x1, 2x2)
// Const---` Pl---`
// =>
// Pl----->Conv(1x1,1x4)----->ReLU---->Eltwise------>Conv(1x1,1x1)------>Eltwise---->Conv(1x1, 1x1)
// Const-->MaxPool(1x1,4x4)-->Squeeze` Pl--->MaxPool(1x1,2x2)-->Squeeze`
TEST(TransformationTests, StridesOptimization8) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{});
auto relu_1 = std::make_shared<ngraph::opset7::Relu>(conv_1);
ngraph::Shape const_shape{1, 3, 224, 224};
auto constant = ngraph::opset7::Constant::create(ngraph::element::f32, const_shape, std::vector<float>(shape_size(const_shape), 1));
auto add = std::make_shared<ngraph::opset7::Add>(relu_1, constant);
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(add, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto data_2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 112, 112});
auto add_2 = std::make_shared<ngraph::opset7::Add>(conv_2, data_2);
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(add_2, weights_3, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3}, ngraph::ParameterVector{data, data_2});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{4, 4},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto relu_1 = std::make_shared<ngraph::opset7::Relu>(conv_1);
ngraph::Shape const_shape{1, 3, 56, 56};
auto constant = ngraph::opset7::Constant::create(ngraph::element::f32, const_shape, std::vector<float>(shape_size(const_shape), 1));
auto add = std::make_shared<ngraph::opset7::Add>(relu_1, constant);
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(add, weights_2, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto data_2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 112, 112});
auto reshape = std::make_shared<ngraph::opset7::Reshape>(data_2,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 3, 112, 112}), false);
auto pool_2 = std::make_shared<ngraph::opset7::MaxPool>(reshape, ngraph::Strides{2, 2}, ngraph::Shape{0, 0},
ngraph::Shape{0, 0}, ngraph::Shape{1, 1});
auto squeeze = std::make_shared<ngraph::opset7::Squeeze>(pool_2,
ngraph::op::Constant::create(ngraph::element::u64, ngraph::Shape{1}, {0}));
auto add_2 = std::make_shared<ngraph::opset7::Add>(conv_2, squeeze);
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(add_2, weights_3, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3}, ngraph::ParameterVector{data, data_2});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
// Pl------->Conv(1x1,1x1)------>Eltwise------>Conv(1x1,2x2)---->Eltwise-->Conv(1x1, 2x2)
// Pl----->Eltwise---->Eltwise--` Pl--->Eltwise------>Eltwise--`
// Const--` Const-` Const--` Const-`
// =>
// Pl------->Conv(1x1,4x4)------->Eltwise---->Conv(1x1,1x1)-->Eltwise-->Conv(1x1, 1x1)
// Pl----->Eltwise----->Eltwise--` Eltwise------>Eltwise--`
// Const--` Const-` Const--` Const-`
TEST(TransformationTests, StridesOptimization9) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{});
auto data_2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{224});
auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{224}, {128});
auto add = std::make_shared<ngraph::opset7::Add>(data_2, add_const);
auto add_2_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{224}, {128});
auto add_2 = std::make_shared<ngraph::opset7::Add>(add, add_2_const);
auto add_3 = std::make_shared<ngraph::opset7::Add>(conv_1, add_2);
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(add_3, weights_2, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto data_3 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto add_4_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {128});
auto add_4 = std::make_shared<ngraph::opset7::Add>(data_3, add_4_const);
auto add_5_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {128});
auto add_5 = std::make_shared<ngraph::opset7::Add>(add_4, add_5_const);
auto add_6 = std::make_shared<ngraph::opset7::Add>(conv_2, add_5);
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(add_6, weights_3, ngraph::Strides{2, 2},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3}, ngraph::ParameterVector{data, data_2, data_3});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::StridesOptimization>();
m.run_passes(f);
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 224, 224});
auto weights_1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_1 = std::make_shared<ngraph::opset7::Convolution>(data, weights_1, ngraph::Strides{4, 4},
ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{});
auto data_2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{224});
auto reshape = std::make_shared<ngraph::opset7::Reshape>(data_2,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 224}), false);
auto pool = std::make_shared<ngraph::opset7::MaxPool>(reshape, ngraph::Strides{4, 4}, ngraph::Shape{0, 0},
ngraph::Shape{0, 0}, ngraph::Shape{1, 1});
auto squeeze = std::make_shared<ngraph::opset7::Squeeze>(pool,
ngraph::op::Constant::create(ngraph::element::u64, ngraph::Shape{3}, {0, 1, 2}));
auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{56}, {128});
auto add = std::make_shared<ngraph::opset7::Add>(squeeze, add_const);
auto add_2_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{56}, {128});
auto add_2 = std::make_shared<ngraph::opset7::Add>(add, add_2_const);
auto add_3 = std::make_shared<ngraph::opset7::Add>(conv_1, add_2);
auto weights_2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_2 = std::make_shared<ngraph::opset7::Convolution>(add_3, weights_2, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
auto data_3 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto new_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 1});
auto reshape_2 = std::make_shared<ngraph::opset7::Reshape>(data_3, new_shape, false);
auto pool_2 = std::make_shared<ngraph::opset7::MaxPool>(reshape_2, ngraph::Strides{2, 2}, ngraph::Shape{0, 0},
ngraph::Shape{0, 0}, ngraph::Shape{1, 1});
auto squeeze_2 = std::make_shared<ngraph::opset7::Squeeze>(pool_2,
ngraph::op::Constant::create(ngraph::element::u64, ngraph::Shape{4}, {0, 1, 2, 3}));
auto add_4_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {128});
auto add_4 = std::make_shared<ngraph::opset7::Add>(squeeze_2, add_4_const);
auto add_5_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {128});
auto add_5 = std::make_shared<ngraph::opset7::Add>(add_4, add_5_const);
auto add_6 = std::make_shared<ngraph::opset7::Add>(conv_2, add_5);
auto weights_3 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {128});
auto conv_3 = std::make_shared<ngraph::opset7::Convolution>(add_6, weights_3, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{}, ngraph::CoordinateDiff{}, ngraph::Strides{});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3}, ngraph::ParameterVector{data, data_2, data_3});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}

View File

@@ -4,6 +4,7 @@
#pragma once
#include <map>
#include <memory>
#include "ngraph/descriptor/tensor.hpp"
@@ -12,6 +13,8 @@ namespace ngraph
{
class Node;
class Variant;
namespace descriptor
{
class Output;
@@ -51,6 +54,11 @@ namespace ngraph
/// \return the tensor of the connected output
Tensor& get_tensor();
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
RTMap& get_rt_info() { return m_rt_info; }
const RTMap& get_rt_info() const { return m_rt_info; }
/// \brief Replace the current output that supplies a value for this input with output i
/// of node
void replace_output(std::shared_ptr<Node> node, size_t i);
@@ -98,6 +106,7 @@ namespace ngraph
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
Output* m_output;
RTMap m_rt_info;
private:
bool m_is_relevant_to_shape;

View File

@@ -84,16 +84,16 @@ namespace ngraph
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
RTMap& Input<Node>::get_rt_info() { return m_node->m_outputs.at(m_index).get_rt_info(); }
RTMap& Input<Node>::get_rt_info() { return m_node->m_inputs.at(m_index).get_rt_info(); }
const RTMap& Input<Node>::get_rt_info() const
{
return m_node->m_outputs.at(m_index).get_rt_info();
return m_node->m_inputs.at(m_index).get_rt_info();
}
const RTMap& Input<const Node>::get_rt_info() const
{
return m_node->m_outputs.at(m_index).get_rt_info();
return m_node->m_inputs.at(m_index).get_rt_info();
}
const Node* Input<const Node>::get_node() const { return m_node; }