[dynamism][CPU] PRelu transformations: dynamic shapes support (#7342)
* [TESTS] CPUUnitTests: unitTestUtils added to CMakeLists * [CPUTransformations] ReshapePrelu: dynamic shapes support * [TESTS] ReshapePrelu tests * [CPUTransformations] ConvertToLeakyRelu: dynamic shapes support * [TESTS] ConvertToLeakyRelu tests * Reshape1D: functions moved to namespace * ReshapePRelu: slope as parameter fix * postreview fixes * cleanup * postreview fixes * ReshapePRelu quick fix
This commit is contained in:
parent
e6bc8c59a9
commit
57a7d3dfcf
@ -12,8 +12,9 @@
|
||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ConvertToLeakyRelu, "ConvertToLeakyRelu", 0);
|
||||
|
||||
MKLDNNPlugin::ConvertToLeakyRelu::ConvertToLeakyRelu() {
|
||||
auto prelu = ngraph::pattern::wrap_type<ngraph::opset1::PRelu>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
|
||||
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())});
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto slope_constant = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
auto prelu = ngraph::pattern::wrap_type<ngraph::opset1::PRelu>({ input, slope_constant });
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
|
||||
auto prelu = std::dynamic_pointer_cast<ngraph::opset1::PRelu>(m.get_match_root());
|
||||
@ -21,7 +22,7 @@ MKLDNNPlugin::ConvertToLeakyRelu::ConvertToLeakyRelu() {
|
||||
return false;
|
||||
}
|
||||
auto slopeNode = std::dynamic_pointer_cast<ngraph::opset1::Constant>(prelu->get_input_node_shared_ptr(1));
|
||||
if (slopeNode != nullptr && ngraph::shape_size(prelu->get_input_shape(1)) == 1) {
|
||||
if (slopeNode != nullptr && ngraph::shape_size(slopeNode->get_shape()) == 1) {
|
||||
const float slope = slopeNode->cast_vector<float>()[0];
|
||||
const auto leakyRelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(prelu->input(0).get_source_output(), slope,
|
||||
prelu->output(0).get_element_type());
|
||||
|
@ -12,32 +12,45 @@
|
||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ReshapePRelu, "ReshapePRelu", 0);
|
||||
|
||||
MKLDNNPlugin::ReshapePRelu::ReshapePRelu() {
|
||||
auto prelu = ngraph::pattern::wrap_type<ngraph::opset1::PRelu>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
|
||||
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())});
|
||||
auto input_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
|
||||
auto slope_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
|
||||
auto prelu_m = ngraph::pattern::wrap_type<ngraph::opset1::PRelu>({ input_m, slope_m });
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
|
||||
auto prelu = std::dynamic_pointer_cast<ngraph::opset1::PRelu>(m.get_match_root());
|
||||
if (!prelu || ngraph::shape_size(prelu->get_input_shape(1)) == 1 || prelu->get_input_shape(1).size() != 1) {
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
const auto prelu = pattern_map.at(prelu_m).get_node_shared_ptr();
|
||||
const auto input = pattern_map.at(input_m);
|
||||
const auto slope = pattern_map.at(slope_m);
|
||||
|
||||
const auto prelu_pshape = prelu->get_input_partial_shape(0);
|
||||
const auto prelu_rank = prelu_pshape.rank();
|
||||
const auto slope_pshape = prelu->get_input_partial_shape(1);
|
||||
const auto slope_rank = slope_pshape.rank();
|
||||
if (prelu_rank.get_length() == 1 || slope_rank.get_length() != 1) {
|
||||
return false;
|
||||
}
|
||||
const auto prelu_shape = prelu->input_value(0).get_shape();
|
||||
const auto slope_shape = prelu->input_value(1).get_shape();
|
||||
ngraph::Shape new_shape(prelu_shape.size(), 1);
|
||||
const auto slope_dim = slope_shape[0];
|
||||
const auto channel_dim_idx = prelu_shape.size() > 1 ? 1 : 0;
|
||||
if (slope_dim != prelu_shape[channel_dim_idx]) {
|
||||
return false;
|
||||
}
|
||||
new_shape[channel_dim_idx] = slope_dim;
|
||||
|
||||
auto slope = ngraph::op::util::reshapeTo(prelu->input_value(1), new_shape);
|
||||
auto new_prelu = std::make_shared<ngraph::opset1::PRelu>(prelu->input(0).get_source_output(), slope);
|
||||
const auto channel_dim_idx = 1;
|
||||
if (slope_pshape.is_static()) {
|
||||
const auto slope_shape = slope_pshape.to_shape();
|
||||
if (!prelu_pshape[channel_dim_idx].is_dynamic() && slope_shape[0] != prelu_pshape[channel_dim_idx].get_length()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::int64_t> target_shape(prelu_rank.get_length(), 1);
|
||||
target_shape[channel_dim_idx] = -1;
|
||||
const auto target_shape_const = ngraph::opset1::Constant::create(ngraph::element::i64, { target_shape.size() }, target_shape);
|
||||
auto new_slope = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(slope, target_shape_const, true);
|
||||
auto new_prelu = prelu->clone_with_new_inputs({ input, new_slope });
|
||||
|
||||
ngraph::replace_node(prelu, new_prelu);
|
||||
new_prelu->set_friendly_name(prelu->get_friendly_name());
|
||||
ngraph::copy_runtime_info(prelu, new_prelu);
|
||||
ngraph::replace_node(prelu, new_prelu);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, "ReshapePRelu");
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu_m, "ReshapePRelu");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
145
inference-engine/tests/unit/cpu/convert_to_leaky_relu_test.cpp
Normal file
145
inference-engine/tests/unit/cpu/convert_to_leaky_relu_test.cpp
Normal file
@ -0,0 +1,145 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph_transformations/convert_to_leaky_relu.hpp>
|
||||
#include <ngraph_transformations/op/leaky_relu.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph_ops/type_relaxed.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace MKLDNNPlugin;
|
||||
|
||||
TEST(TransformationTests, ConvertToLeakyReluTest1) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertToLeakyRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto prelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(input, -2.f, ngraph::element::f32);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertToLeakyReluTest2) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertToLeakyRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
|
||||
auto prelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(input, -2.f, ngraph::element::f32);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertToLeakyReluTest3) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertToLeakyRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto prelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(input, -2.f, ngraph::element::f32);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertToLeakyReluTest4) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
|
||||
auto relaxed_prelu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::PRelu>>(
|
||||
ngraph::element::TypeVector{ ngraph::element::f32, ngraph::element::f32 },
|
||||
ngraph::element::TypeVector{ ngraph::element::f32 },
|
||||
ngraph::op::TemporaryReplaceOutputType(input, ngraph::element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(slope, ngraph::element::f32).get());
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ relaxed_prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertToLeakyRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto prelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(input, -2.f, ngraph::element::f32);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertToLeakyReluTest5) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
f_ref = f;
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertToLeakyRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
250
inference-engine/tests/unit/cpu/reshape_prelu_test.cpp
Normal file
250
inference-engine/tests/unit/cpu/reshape_prelu_test.cpp
Normal file
@ -0,0 +1,250 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph_transformations/reshape_prelu.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph_ops/type_relaxed.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace MKLDNNPlugin;
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest1) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3, 1, 1 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest2) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input_pshape = ngraph::PartialShape{ ngraph::Dimension::dynamic(), 3, ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic() };
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_pshape);
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input_pshape = ngraph::PartialShape{ ngraph::Dimension::dynamic(), 3, ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic() };
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_pshape);
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3, 1, 1 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest3) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3, 1, 1 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest4) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
f_ref = f;
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest5) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
f_ref = f;
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest6) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 4, 4 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 4 }, { -2.f, -1.f, -2.f, -1.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
f_ref = f;
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest7) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
|
||||
auto relaxed_prelu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::PRelu>>(
|
||||
ngraph::element::TypeVector{ ngraph::element::f32, ngraph::element::f32 },
|
||||
ngraph::element::TypeVector{ ngraph::element::f32 },
|
||||
ngraph::op::TemporaryReplaceOutputType(input, ngraph::element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(slope, ngraph::element::f32).get());
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ relaxed_prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3, 1, 1 }, { -2.f, -1.f, -2.f });
|
||||
auto relaxed_prelu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::PRelu>>(
|
||||
ngraph::element::TypeVector{ ngraph::element::f32, ngraph::element::f32 },
|
||||
ngraph::element::TypeVector{ ngraph::element::f32 },
|
||||
ngraph::op::TemporaryReplaceOutputType(input, ngraph::element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(slope, ngraph::element::f32).get());
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ relaxed_prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest8) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 4, 3 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 4, 3 });
|
||||
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3 }, { -2.f, -1.f, -2.f });
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReshapePReluTest9) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
|
||||
auto slope = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input, slope });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ReshapePRelu>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
|
||||
auto slope = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
|
||||
|
||||
auto shape_of = std::make_shared<ngraph::opset1::ShapeOf>(slope);
|
||||
auto reshape_const = ngraph::opset1::Constant::create(ngraph::element::i64, { 4 }, { 1, -1, 1, 1 });
|
||||
auto reshape = std::make_shared<ngraph::opset1::Reshape>(slope, reshape_const, true);
|
||||
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, reshape);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input, slope });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
Loading…
Reference in New Issue
Block a user