[dynamism][CPU] PRelu transformations: dynamic shapes support (#7342)

* [TESTS] CPUUnitTests: unitTestUtils added to CMakeLists

* [CPUTransformations] ReshapePrelu: dynamic shapes support

* [TESTS] ReshapePrelu tests

* [CPUTransformations] ConvertToLeakyRelu: dynamic shapes support

* [TESTS] ConvertToLeakyRelu tests

* Reshape1D: functions moved to namespace

* ReshapePRelu: slope as parameter fix

* postreview fixes

* cleanup

* postreview fixes

* ReshapePRelu quick fix
This commit is contained in:
Vladislav Golubev 2021-10-25 14:12:11 +03:00 committed by GitHub
parent e6bc8c59a9
commit 57a7d3dfcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 430 additions and 21 deletions

View File

@ -12,8 +12,9 @@
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ConvertToLeakyRelu, "ConvertToLeakyRelu", 0);
MKLDNNPlugin::ConvertToLeakyRelu::ConvertToLeakyRelu() {
auto prelu = ngraph::pattern::wrap_type<ngraph::opset1::PRelu>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())});
auto input = ngraph::pattern::any_input();
auto slope_constant = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
auto prelu = ngraph::pattern::wrap_type<ngraph::opset1::PRelu>({ input, slope_constant });
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
auto prelu = std::dynamic_pointer_cast<ngraph::opset1::PRelu>(m.get_match_root());
@ -21,7 +22,7 @@ MKLDNNPlugin::ConvertToLeakyRelu::ConvertToLeakyRelu() {
return false;
}
auto slopeNode = std::dynamic_pointer_cast<ngraph::opset1::Constant>(prelu->get_input_node_shared_ptr(1));
if (slopeNode != nullptr && ngraph::shape_size(prelu->get_input_shape(1)) == 1) {
if (slopeNode != nullptr && ngraph::shape_size(slopeNode->get_shape()) == 1) {
const float slope = slopeNode->cast_vector<float>()[0];
const auto leakyRelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(prelu->input(0).get_source_output(), slope,
prelu->output(0).get_element_type());

View File

@ -12,32 +12,45 @@
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ReshapePRelu, "ReshapePRelu", 0);
MKLDNNPlugin::ReshapePRelu::ReshapePRelu() {
auto prelu = ngraph::pattern::wrap_type<ngraph::opset1::PRelu>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())});
auto input_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
auto slope_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
auto prelu_m = ngraph::pattern::wrap_type<ngraph::opset1::PRelu>({ input_m, slope_m });
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
auto prelu = std::dynamic_pointer_cast<ngraph::opset1::PRelu>(m.get_match_root());
if (!prelu || ngraph::shape_size(prelu->get_input_shape(1)) == 1 || prelu->get_input_shape(1).size() != 1) {
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto prelu = pattern_map.at(prelu_m).get_node_shared_ptr();
const auto input = pattern_map.at(input_m);
const auto slope = pattern_map.at(slope_m);
const auto prelu_pshape = prelu->get_input_partial_shape(0);
const auto prelu_rank = prelu_pshape.rank();
const auto slope_pshape = prelu->get_input_partial_shape(1);
const auto slope_rank = slope_pshape.rank();
if (prelu_rank.get_length() == 1 || slope_rank.get_length() != 1) {
return false;
}
const auto prelu_shape = prelu->input_value(0).get_shape();
const auto slope_shape = prelu->input_value(1).get_shape();
ngraph::Shape new_shape(prelu_shape.size(), 1);
const auto slope_dim = slope_shape[0];
const auto channel_dim_idx = prelu_shape.size() > 1 ? 1 : 0;
if (slope_dim != prelu_shape[channel_dim_idx]) {
return false;
}
new_shape[channel_dim_idx] = slope_dim;
auto slope = ngraph::op::util::reshapeTo(prelu->input_value(1), new_shape);
auto new_prelu = std::make_shared<ngraph::opset1::PRelu>(prelu->input(0).get_source_output(), slope);
const auto channel_dim_idx = 1;
if (slope_pshape.is_static()) {
const auto slope_shape = slope_pshape.to_shape();
if (!prelu_pshape[channel_dim_idx].is_dynamic() && slope_shape[0] != prelu_pshape[channel_dim_idx].get_length()) {
return false;
}
}
std::vector<std::int64_t> target_shape(prelu_rank.get_length(), 1);
target_shape[channel_dim_idx] = -1;
const auto target_shape_const = ngraph::opset1::Constant::create(ngraph::element::i64, { target_shape.size() }, target_shape);
auto new_slope = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(slope, target_shape_const, true);
auto new_prelu = prelu->clone_with_new_inputs({ input, new_slope });
ngraph::replace_node(prelu, new_prelu);
new_prelu->set_friendly_name(prelu->get_friendly_name());
ngraph::copy_runtime_info(prelu, new_prelu);
ngraph::replace_node(prelu, new_prelu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, "ReshapePRelu");
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu_m, "ReshapePRelu");
this->register_matcher(m, callback);
}

View File

@ -0,0 +1,145 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph_transformations/convert_to_leaky_relu.hpp>
#include <ngraph_transformations/op/leaky_relu.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph_ops/type_relaxed.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace MKLDNNPlugin;
TEST(TransformationTests, ConvertToLeakyReluTest1) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ConvertToLeakyRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
auto prelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(input, -2.f, ngraph::element::f32);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertToLeakyReluTest2) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ConvertToLeakyRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
auto prelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(input, -2.f, ngraph::element::f32);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertToLeakyReluTest3) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ConvertToLeakyRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto prelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(input, -2.f, ngraph::element::f32);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertToLeakyReluTest4) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::Shape{ 1, 3, 16, 16 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
auto relaxed_prelu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::PRelu>>(
ngraph::element::TypeVector{ ngraph::element::f32, ngraph::element::f32 },
ngraph::element::TypeVector{ ngraph::element::f32 },
ngraph::op::TemporaryReplaceOutputType(input, ngraph::element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(slope, ngraph::element::f32).get());
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ relaxed_prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ConvertToLeakyRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::Shape{ 1, 3, 16, 16 });
auto prelu = std::make_shared<MKLDNNPlugin::LeakyReluNode>(input, -2.f, ngraph::element::f32);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertToLeakyReluTest5) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
f_ref = f;
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ConvertToLeakyRelu>();
m.run_passes(f);
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -0,0 +1,250 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph_transformations/reshape_prelu.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph_ops/type_relaxed.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace MKLDNNPlugin;
TEST(TransformationTests, ReshapePReluTest1) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3, 1, 1 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReshapePReluTest2) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input_pshape = ngraph::PartialShape{ ngraph::Dimension::dynamic(), 3, ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic() };
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_pshape);
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
{
auto input_pshape = ngraph::PartialShape{ ngraph::Dimension::dynamic(), 3, ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic() };
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_pshape);
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3, 1, 1 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReshapePReluTest3) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3, 1, 1 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReshapePReluTest4) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
f_ref = f;
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReshapePReluTest5) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
f_ref = f;
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReshapePReluTest6) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 4, 4 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 4 }, { -2.f, -1.f, -2.f, -1.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
f_ref = f;
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReshapePReluTest7) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::Shape{ 1, 3, 16, 16 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
auto relaxed_prelu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::PRelu>>(
ngraph::element::TypeVector{ ngraph::element::f32, ngraph::element::f32 },
ngraph::element::TypeVector{ ngraph::element::f32 },
ngraph::op::TemporaryReplaceOutputType(input, ngraph::element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(slope, ngraph::element::f32).get());
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ relaxed_prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::Shape{ 1, 3, 16, 16 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3, 1, 1 }, { -2.f, -1.f, -2.f });
auto relaxed_prelu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::PRelu>>(
ngraph::element::TypeVector{ ngraph::element::f32, ngraph::element::f32 },
ngraph::element::TypeVector{ ngraph::element::f32 },
ngraph::op::TemporaryReplaceOutputType(input, ngraph::element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(slope, ngraph::element::f32).get());
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ relaxed_prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReshapePReluTest8) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 4, 3 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 4, 3 });
auto slope = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 3 }, { -2.f, -1.f, -2.f });
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReshapePReluTest9) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
auto slope = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, slope);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input, slope });
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ReshapePRelu>();
m.run_passes(f);
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
auto slope = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
auto shape_of = std::make_shared<ngraph::opset1::ShapeOf>(slope);
auto reshape_const = ngraph::opset1::Constant::create(ngraph::element::i64, { 4 }, { 1, -1, 1, 1 });
auto reshape = std::make_shared<ngraph::opset1::Reshape>(slope, reshape_const, true);
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input, reshape);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ prelu }, ngraph::ParameterVector{ input, slope });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}