Fix wrong statement in ConvertConvolutions transformation (#3056)
* Fixed wrong statement in ConvertConvolution transformation * Added tests * FIXME: 42956
This commit is contained in:
parent
eb82adeb3a
commit
252b99dc01
@ -59,17 +59,19 @@ ngraph::pass::ConvertGroupConvolution::ConvertGroupConvolution() {
|
|||||||
|
|
||||||
// Merge weights layout GOIYX to (G*O)IYX
|
// Merge weights layout GOIYX to (G*O)IYX
|
||||||
auto shape = gconv->input_value(1).get_shape();
|
auto shape = gconv->input_value(1).get_shape();
|
||||||
std::vector<int64_t> reshape_shape{static_cast<int64_t>(shape[0] * shape[1])};
|
Shape reshape_shape{static_cast<size_t>(shape[0] * shape[1])};
|
||||||
for (size_t i = 2; i < shape.size(); ++i) {
|
for (size_t i = 2; i < shape.size(); ++i) {
|
||||||
reshape_shape.push_back(shape[i]);
|
reshape_shape.push_back(shape[i]);
|
||||||
}
|
}
|
||||||
Output<Node> weights;
|
Output<Node> weights;
|
||||||
auto w_input = gconv->input_value(1).get_node_shared_ptr();
|
auto w_input = gconv->input_value(1).get_node_shared_ptr();
|
||||||
if (std::dynamic_pointer_cast<opset1::Reshape>(w_input) && w_input->input_value(0).get_shape().size() == w_input->get_output_shape(0).size() - 1) {
|
if (std::dynamic_pointer_cast<opset1::Reshape>(w_input) && w_input->input_value(0).get_shape() == reshape_shape) {
|
||||||
weights = w_input->input_value(0);
|
weights = w_input->input_value(0);
|
||||||
} else {
|
} else {
|
||||||
weights = std::make_shared<ngraph::opset1::Reshape>(gconv->input_value(1),
|
weights = std::make_shared<ngraph::opset1::Reshape>(gconv->input_value(1),
|
||||||
op::Constant::create(element::i64, Shape{reshape_shape.size()}, reshape_shape), true);
|
op::Constant::create(element::i64, Shape{reshape_shape.size()}, reshape_shape), true);
|
||||||
|
// FIXME: 42956
|
||||||
|
// ngraph::copy_runtime_info(gconv, weights.get_node_shared_ptr());
|
||||||
}
|
}
|
||||||
auto conv_ie = std::make_shared<ngraph::op::ConvolutionIE>(gconv->input_value(0),
|
auto conv_ie = std::make_shared<ngraph::op::ConvolutionIE>(gconv->input_value(0),
|
||||||
weights,
|
weights,
|
||||||
|
@ -20,18 +20,21 @@
|
|||||||
#include <ngraph/pass/visualize_tree.hpp>
|
#include <ngraph/pass/visualize_tree.hpp>
|
||||||
#include <transformations/op_conversions/convert_convolutions.hpp>
|
#include <transformations/op_conversions/convert_convolutions.hpp>
|
||||||
#include <ngraph_ops/convolution_ie.hpp>
|
#include <ngraph_ops/convolution_ie.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
|
||||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
using namespace testing;
|
using namespace testing;
|
||||||
|
using namespace ngraph;
|
||||||
|
using namespace ngraph::opset1;
|
||||||
|
|
||||||
using InputShape = ngraph::PartialShape;
|
using InputShape = PartialShape;
|
||||||
using WeightsShape = ngraph::Shape;
|
using WeightsShape = PartialShape;
|
||||||
|
|
||||||
class ConvertConvolutionTest: public CommonTestUtils::TestsCommon,
|
class ConvertConvolutionTest: public CommonTestUtils::TestsCommon,
|
||||||
public testing::WithParamInterface<std::tuple<InputShape, WeightsShape> > {
|
public testing::WithParamInterface<std::tuple<InputShape, WeightsShape> > {
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
std::shared_ptr<Function> f, f_ref;
|
||||||
|
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
const auto& input_shape = std::get<0>(GetParam());
|
const auto& input_shape = std::get<0>(GetParam());
|
||||||
@ -42,33 +45,35 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
|
std::shared_ptr<Function> get_initial_function(const PartialShape & input_shape,
|
||||||
const ngraph::Shape & weights_shape) {
|
const PartialShape & weights_shape) {
|
||||||
|
assert(weights_shape.is_static());
|
||||||
auto spatial_dims = input_shape.rank().get_length() - 2;
|
auto spatial_dims = input_shape.rank().get_length() - 2;
|
||||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
auto input = std::make_shared<Parameter>(element::f32, input_shape);
|
||||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
|
auto weights = Constant::create(element::f32, weights_shape.to_shape(), {1});
|
||||||
auto conv = std::make_shared<ngraph::opset1::Convolution>(input, weights, ngraph::Strides(spatial_dims, 1),
|
auto conv = std::make_shared<Convolution>(input, weights, Strides(spatial_dims, 1),
|
||||||
ngraph::CoordinateDiff(spatial_dims, 0), ngraph::CoordinateDiff(spatial_dims, 0), ngraph::Strides(spatial_dims, 1));
|
CoordinateDiff(spatial_dims, 0), CoordinateDiff(spatial_dims, 0), Strides(spatial_dims, 1));
|
||||||
|
|
||||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{conv}, ngraph::ParameterVector{input});
|
return std::make_shared<Function>(NodeVector{conv}, ParameterVector{input});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialShape & input_shape,
|
std::shared_ptr<Function> get_reference_function(const PartialShape & input_shape,
|
||||||
const ngraph::Shape & weights_shape) {
|
const PartialShape & weights_shape) {
|
||||||
|
assert(weights_shape.is_static());
|
||||||
auto spatial_dims = input_shape.rank().get_length() - 2;
|
auto spatial_dims = input_shape.rank().get_length() - 2;
|
||||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
auto input = std::make_shared<Parameter>(element::f32, input_shape);
|
||||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
|
auto weights = Constant::create(element::f32, weights_shape.get_shape(), {1});
|
||||||
auto conv = std::make_shared<ngraph::op::ConvolutionIE>(input, weights, ngraph::Strides(spatial_dims, 1), ngraph::Strides(spatial_dims, 1),
|
auto conv = std::make_shared<op::ConvolutionIE>(input, weights, Strides(spatial_dims, 1), Strides(spatial_dims, 1),
|
||||||
ngraph::CoordinateDiff(spatial_dims, 0), ngraph::CoordinateDiff(spatial_dims, 0), ngraph::element::f32);
|
CoordinateDiff(spatial_dims, 0), CoordinateDiff(spatial_dims, 0), element::f32);
|
||||||
|
|
||||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{conv}, ngraph::ParameterVector{input});
|
return std::make_shared<Function>(NodeVector{conv}, ParameterVector{input});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_P(ConvertConvolutionTest, CompareFunctions) {
|
TEST_P(ConvertConvolutionTest, CompareFunctions) {
|
||||||
const auto & orig_shape = f->get_output_partial_shape(0);
|
const auto & orig_shape = f->get_output_partial_shape(0);
|
||||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
pass::InitNodeInfo().run_on_function(f);
|
||||||
ngraph::pass::ConvertConvolutions().run_on_function(f);
|
pass::ConvertConvolutions().run_on_function(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
auto res = compare_functions(f, f_ref);
|
auto res = compare_functions(f, f_ref);
|
||||||
ASSERT_TRUE(res.first) << res.second;
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
@ -92,3 +97,86 @@ INSTANTIATE_TEST_CASE_P(ConvertConvolution, ConvertConvolutionTest,
|
|||||||
std::make_tuple(InputShape{DYN, 3, 10}, WeightsShape{3, 3, 1}),
|
std::make_tuple(InputShape{DYN, 3, 10}, WeightsShape{3, 3, 1}),
|
||||||
std::make_tuple(InputShape{2, DYN, 9}, WeightsShape{2, 3, 2}),
|
std::make_tuple(InputShape{2, DYN, 9}, WeightsShape{2, 3, 2}),
|
||||||
std::make_tuple(InputShape{3, 3, DYN}, WeightsShape{1, 3, 3})));
|
std::make_tuple(InputShape{3, 3, DYN}, WeightsShape{1, 3, 3})));
|
||||||
|
|
||||||
|
TEST(ConvertConvolutionTest, GroupConvolutionWithReshape) {
|
||||||
|
PartialShape input_shape{1, 6, 64, 64};
|
||||||
|
PartialShape weights_shape_before{2 * 3, 3, 5, 5};
|
||||||
|
PartialShape weights_shape_after{2, 3, 3, 5, 5};
|
||||||
|
size_t group = 2;
|
||||||
|
|
||||||
|
std::shared_ptr<Function> f, f_ref;
|
||||||
|
{
|
||||||
|
auto spatial_dims = input_shape.rank().get_length() - 2;
|
||||||
|
auto input = std::make_shared<Parameter>(element::f32, input_shape);
|
||||||
|
auto weights = std::make_shared<Parameter>(element::f32, weights_shape_before);
|
||||||
|
auto reshape = std::make_shared<Reshape>(weights, Constant::create(element::i64,
|
||||||
|
Shape{static_cast<size_t>(weights_shape_after.rank().get_length())}, weights_shape_after.to_shape()), true);
|
||||||
|
auto conv = std::make_shared<GroupConvolution>(input, reshape, Strides(spatial_dims, 1),
|
||||||
|
CoordinateDiff(spatial_dims, 0), CoordinateDiff(spatial_dims, 0), Strides(spatial_dims, 1));
|
||||||
|
|
||||||
|
f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{input, weights});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto spatial_dims = input_shape.rank().get_length() - 2;
|
||||||
|
auto input = std::make_shared<Parameter>(element::f32, input_shape);
|
||||||
|
auto weights = std::make_shared<Parameter>(element::f32, weights_shape_before);
|
||||||
|
auto conv = std::make_shared<op::ConvolutionIE>(input, weights, Strides(spatial_dims, 1), Strides(spatial_dims, 1),
|
||||||
|
CoordinateDiff(spatial_dims, 0), CoordinateDiff(spatial_dims, 0), element::f32, group);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{input, weights});
|
||||||
|
}
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<pass::ConvertConvolutions>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST(ConvertConvolutionTest, GroupConvolutionWithReshapeNeg) {
|
||||||
|
PartialShape input_shape{1, 6, 64, 64};
|
||||||
|
PartialShape weights_shape_before{3, 2, 3, 5, 5};
|
||||||
|
PartialShape weights_shape_after{2, 3, 3, 5, 5};
|
||||||
|
PartialShape weights_shape_ref{2 * 3, 3, 5, 5};
|
||||||
|
size_t group = 2;
|
||||||
|
|
||||||
|
std::shared_ptr<Function> f, f_ref;
|
||||||
|
{
|
||||||
|
auto spatial_dims = input_shape.rank().get_length() - 2;
|
||||||
|
auto input = std::make_shared<Parameter>(element::f32, input_shape);
|
||||||
|
auto weights = std::make_shared<Parameter>(element::f32, weights_shape_before);
|
||||||
|
auto reshape = std::make_shared<Reshape>(weights, Constant::create(element::i64,
|
||||||
|
Shape{static_cast<size_t>(weights_shape_after.rank().get_length())}, weights_shape_after.to_shape()), true);
|
||||||
|
auto conv = std::make_shared<GroupConvolution>(input, reshape, Strides(spatial_dims, 1),
|
||||||
|
CoordinateDiff(spatial_dims, 0), CoordinateDiff(spatial_dims, 0), Strides(spatial_dims, 1));
|
||||||
|
|
||||||
|
f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{input, weights});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto spatial_dims = input_shape.rank().get_length() - 2;
|
||||||
|
auto input = std::make_shared<Parameter>(element::f32, input_shape);
|
||||||
|
auto weights_param = std::make_shared<Parameter>(element::f32, weights_shape_before);
|
||||||
|
auto reshape = std::make_shared<Reshape>(weights_param, Constant::create(element::i64,
|
||||||
|
Shape{static_cast<size_t>(weights_shape_after.rank().get_length())}, weights_shape_after.to_shape()), true);
|
||||||
|
auto weights = std::make_shared<Reshape>(reshape, Constant::create(element::i64,
|
||||||
|
Shape{static_cast<size_t>(weights_shape_ref.rank().get_length())}, weights_shape_ref.to_shape()), true);
|
||||||
|
auto conv = std::make_shared<op::ConvolutionIE>(input, weights, Strides(spatial_dims, 1), Strides(spatial_dims, 1),
|
||||||
|
CoordinateDiff(spatial_dims, 0), CoordinateDiff(spatial_dims, 0), element::f32, group);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{input, weights_param});
|
||||||
|
}
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<pass::ConvertConvolutions>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
// FIXME: 42956
|
||||||
|
// ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user