Updated DeconvolutionIE to support dynamic shapes (#671)
* Updated DeconvolutionIE to support dynamic shapes * Updated DeconvolutionIE to support output_shape input * Updated ConvertConvolutions pass
This commit is contained in:
parent
cbe45b7d0a
commit
638c7b891c
@ -26,31 +26,30 @@ public:
|
||||
DeconvolutionIE(const Output<Node>& data,
|
||||
const Output<Node>& filters,
|
||||
const Strides& strides,
|
||||
const Strides& dilations,
|
||||
const CoordinateDiff& pads_begin,
|
||||
const CoordinateDiff& pads_end,
|
||||
const Strides& dilations,
|
||||
const Shape& output_shape,
|
||||
const size_t& group = 1,
|
||||
const PadType& auto_pad = PadType::EXPLICIT);
|
||||
const PadType& auto_pad = PadType::EXPLICIT,
|
||||
const CoordinateDiff& output_padding = {},
|
||||
const std::shared_ptr<Node> & output_shape = nullptr);
|
||||
|
||||
DeconvolutionIE(const Output<Node>& data,
|
||||
const Output<Node>& filters,
|
||||
const Output<Node>& bias,
|
||||
const Strides& strides,
|
||||
const Strides& dilations,
|
||||
const CoordinateDiff& pads_begin,
|
||||
const CoordinateDiff& pads_end,
|
||||
const Strides& dilations,
|
||||
const Shape& output_shape,
|
||||
const size_t& group = 1,
|
||||
const PadType& auto_pad = PadType::EXPLICIT);
|
||||
const PadType& auto_pad = PadType::EXPLICIT,
|
||||
const CoordinateDiff& output_padding = {},
|
||||
const std::shared_ptr<Node> & output_shape = nullptr);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector & new_args) const override;
|
||||
|
||||
/// \return The data batch shape.
|
||||
const PartialShape get_output_shape() { return m_output_shape; }
|
||||
void set_output_shape(const Shape& output_shape) { m_output_shape = output_shape; }
|
||||
/// \return The strides from the forward prop.
|
||||
const Strides& get_strides() const { return m_strides; }
|
||||
void set_strides(const Strides& strides) { m_strides = strides; }
|
||||
@ -75,9 +74,10 @@ protected:
|
||||
Strides m_dilations;
|
||||
CoordinateDiff m_pads_begin;
|
||||
CoordinateDiff m_pads_end;
|
||||
CoordinateDiff m_output_padding;
|
||||
PadType m_auto_pad;
|
||||
Shape m_output_shape;
|
||||
size_t m_group;
|
||||
std::shared_ptr<Node> m_output_shape;
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -5,11 +5,13 @@
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <ngraph/ops.hpp>
|
||||
|
||||
#include "ngraph_ops/deconvolution_ie.hpp"
|
||||
|
||||
#include "ngraph/util.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -19,20 +21,22 @@ constexpr NodeTypeInfo op::DeconvolutionIE::type_info;
|
||||
op::DeconvolutionIE::DeconvolutionIE(const Output<Node>& data,
|
||||
const Output<Node>& filters,
|
||||
const Strides& strides,
|
||||
const Strides& dilations,
|
||||
const CoordinateDiff& pads_begin,
|
||||
const CoordinateDiff& pads_end,
|
||||
const Strides& dilations,
|
||||
const Shape& output_shape,
|
||||
const size_t& group,
|
||||
const PadType& auto_pad)
|
||||
const PadType& auto_pad,
|
||||
const CoordinateDiff& output_padding,
|
||||
const std::shared_ptr<Node> & output_shape)
|
||||
: Op({data, filters})
|
||||
, m_strides(strides)
|
||||
, m_dilations(dilations)
|
||||
, m_pads_begin(pads_begin)
|
||||
, m_pads_end(pads_end)
|
||||
, m_auto_pad(auto_pad)
|
||||
, m_output_shape(output_shape)
|
||||
, m_group(group) {
|
||||
, m_group(group)
|
||||
, m_output_padding(output_padding)
|
||||
, m_output_shape(output_shape) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
@ -40,48 +44,75 @@ op::DeconvolutionIE::DeconvolutionIE(const Output<Node>& data,
|
||||
const Output<Node>& filters,
|
||||
const Output<Node>& bias,
|
||||
const Strides& strides,
|
||||
const Strides& dilations,
|
||||
const CoordinateDiff& pads_begin,
|
||||
const CoordinateDiff& pads_end,
|
||||
const Strides& dilations,
|
||||
const Shape& output_shape,
|
||||
const size_t& group,
|
||||
const PadType& auto_pad)
|
||||
const PadType& auto_pad,
|
||||
const CoordinateDiff& output_padding,
|
||||
const std::shared_ptr<Node> & output_shape)
|
||||
: Op({data, filters, bias})
|
||||
, m_strides(strides)
|
||||
, m_dilations(dilations)
|
||||
, m_pads_begin(pads_begin)
|
||||
, m_pads_end(pads_end)
|
||||
, m_auto_pad(auto_pad)
|
||||
, m_output_shape(output_shape)
|
||||
, m_group(group) {
|
||||
, m_group(group)
|
||||
, m_output_padding(output_padding)
|
||||
, m_output_shape(output_shape) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::DeconvolutionIE::validate_and_infer_types() {
|
||||
set_output_type(0, get_input_element_type(0), m_output_shape);
|
||||
// To calculate output shape we use opset1::GroupConvolutionBackPropData
|
||||
// but before we need to reshape weights from I(G*O)YX to GIOYX
|
||||
auto weights = input_value(1);
|
||||
const auto weights_pshape = weights.get_partial_shape();
|
||||
const auto weights_shape_rank = weights_pshape.rank();
|
||||
if (weights_pshape.is_static()) {
|
||||
auto weights_shape = weights_pshape.to_shape();
|
||||
std::vector<int64_t> reshape_dims(3);
|
||||
reshape_dims[0] = m_group; // G
|
||||
reshape_dims[1] = weights_shape[0]; // I
|
||||
reshape_dims[2] = weights_shape[1] / m_group; // O
|
||||
reshape_dims.insert(reshape_dims.end(), weights_shape.begin() + 2, weights_shape.end());
|
||||
weights = std::make_shared<opset1::Reshape>(weights, opset1::Constant::create(element::i64, Shape{reshape_dims.size()}, reshape_dims), true);
|
||||
}
|
||||
Output<Node> conv;
|
||||
if (m_output_shape) {
|
||||
conv = std::make_shared<opset1::GroupConvolutionBackpropData>(input_value(0), weights, m_output_shape,
|
||||
m_strides, m_pads_begin, m_pads_end, m_dilations, m_auto_pad, m_output_padding);
|
||||
} else {
|
||||
conv = std::make_shared<opset1::GroupConvolutionBackpropData>(input_value(0), weights,
|
||||
m_strides, m_pads_begin, m_pads_end, m_dilations, m_auto_pad, m_output_padding);
|
||||
}
|
||||
set_output_type(0, conv.get_element_type(), conv.get_partial_shape());
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::DeconvolutionIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::DeconvolutionIE::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
|
||||
if (new_args.size() == 2) {
|
||||
return make_shared<DeconvolutionIE>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
m_strides,
|
||||
m_pads_begin,
|
||||
m_pads_end,
|
||||
m_dilations,
|
||||
m_output_shape,
|
||||
m_group,
|
||||
m_auto_pad);
|
||||
} else {
|
||||
new_args.at(1),
|
||||
m_strides,
|
||||
m_dilations,
|
||||
m_pads_begin,
|
||||
m_pads_end,
|
||||
m_group,
|
||||
m_auto_pad,
|
||||
m_output_padding,
|
||||
m_output_shape);
|
||||
} else if (new_args.size() == 3) {
|
||||
return make_shared<DeconvolutionIE>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
new_args.at(2),
|
||||
m_strides,
|
||||
m_pads_begin,
|
||||
m_pads_end,
|
||||
m_dilations,
|
||||
m_output_shape,
|
||||
m_group,
|
||||
m_auto_pad);
|
||||
new_args.at(1),
|
||||
new_args.at(2),
|
||||
m_strides,
|
||||
m_dilations,
|
||||
m_pads_begin,
|
||||
m_pads_end,
|
||||
m_group,
|
||||
m_auto_pad,
|
||||
m_output_padding,
|
||||
m_output_shape);
|
||||
}
|
||||
throw ngraph::ngraph_error("Unexpected number of arguments");
|
||||
}
|
||||
|
@ -14,14 +14,8 @@
|
||||
#include <ngraph_ops/deconvolution_ie.hpp>
|
||||
|
||||
void ngraph::pass::ConvertConvolutions::convert_convolution() {
|
||||
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3, 12, 12});
|
||||
auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3, 1, 1});
|
||||
auto conv = std::make_shared<ngraph::opset1::Convolution>(data,
|
||||
weights,
|
||||
Strides{1, 1},
|
||||
CoordinateDiff{0, 0},
|
||||
CoordinateDiff{0, 0},
|
||||
Strides{1, 1});
|
||||
auto conv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
|
||||
pattern::has_class<opset1::Convolution>());
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
auto conv = std::dynamic_pointer_cast<ngraph::opset1::Convolution> (m.get_match_root());
|
||||
@ -48,14 +42,8 @@ void ngraph::pass::ConvertConvolutions::convert_convolution() {
|
||||
}
|
||||
|
||||
void ngraph::pass::ConvertConvolutions::convert_group_convolution() {
|
||||
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3, 12, 12});
|
||||
auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 1, 1, 1, 1});
|
||||
auto gconv = std::make_shared<ngraph::opset1::GroupConvolution>(data,
|
||||
weights,
|
||||
Strides{1, 1},
|
||||
CoordinateDiff{0, 0},
|
||||
CoordinateDiff{0, 0},
|
||||
Strides{1, 1});
|
||||
auto gconv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
|
||||
pattern::has_class<opset1::GroupConvolution>());
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
auto gconv = std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution> (m.get_match_root());
|
||||
@ -97,23 +85,8 @@ void ngraph::pass::ConvertConvolutions::convert_group_convolution() {
|
||||
}
|
||||
|
||||
void ngraph::pass::ConvertConvolutions::convert_convolution_backprop_data() {
|
||||
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3, 12, 12});
|
||||
auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3, 1, 1});
|
||||
auto conv = std::make_shared<ngraph::opset1::ConvolutionBackpropData>(data,
|
||||
weights,
|
||||
Strides{1, 1},
|
||||
CoordinateDiff{0, 0},
|
||||
CoordinateDiff{0, 0},
|
||||
Strides{1, 1});
|
||||
|
||||
auto output_shape = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
|
||||
auto conv2 = std::make_shared<ngraph::opset1::ConvolutionBackpropData>(data,
|
||||
weights,
|
||||
output_shape,
|
||||
Strides{1, 1},
|
||||
CoordinateDiff{0, 0},
|
||||
CoordinateDiff{0, 0},
|
||||
Strides{1, 1});
|
||||
auto conv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
|
||||
pattern::has_class<opset1::ConvolutionBackpropData>());
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
auto deconv = std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData> (m.get_match_root());
|
||||
@ -124,12 +97,14 @@ void ngraph::pass::ConvertConvolutions::convert_convolution_backprop_data() {
|
||||
auto deconv_ie = std::make_shared<ngraph::op::DeconvolutionIE>(deconv->input_value(0),
|
||||
deconv->input_value(1),
|
||||
deconv->get_strides(),
|
||||
deconv->get_dilations(),
|
||||
deconv->get_pads_begin(),
|
||||
deconv->get_pads_end(),
|
||||
deconv->get_dilations(),
|
||||
deconv->output(0).get_shape(),
|
||||
1 /* groups */,
|
||||
deconv->get_auto_pad());
|
||||
deconv->get_auto_pad(),
|
||||
deconv->get_output_padding(),
|
||||
(deconv->inputs().size() == 3 ? deconv->input_value(2).get_node_shared_ptr()
|
||||
: nullptr));
|
||||
deconv_ie->set_friendly_name(deconv->get_friendly_name());
|
||||
ngraph::copy_runtime_info(deconv, deconv_ie);
|
||||
ngraph::replace_node(deconv, deconv_ie);
|
||||
@ -138,29 +113,11 @@ void ngraph::pass::ConvertConvolutions::convert_convolution_backprop_data() {
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "ConvertConvolutionBackpropData");
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
|
||||
auto m2 = std::make_shared<ngraph::pattern::Matcher>(conv2, "ConvertConvolutionBackpropData2");
|
||||
this->add_matcher(m2, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
||||
|
||||
void ngraph::pass::ConvertConvolutions::convert_group_convolution_backprop_data() {
|
||||
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3, 12, 12});
|
||||
auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 1, 1, 1, 1});
|
||||
auto gconv = std::make_shared<ngraph::opset1::GroupConvolutionBackpropData>(data,
|
||||
weights,
|
||||
Strides{1, 1},
|
||||
CoordinateDiff{0, 0},
|
||||
CoordinateDiff{0, 0},
|
||||
Strides{1, 1});
|
||||
|
||||
auto output_shape = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
|
||||
auto gconv2 = std::make_shared<ngraph::opset1::GroupConvolutionBackpropData>(data,
|
||||
weights,
|
||||
output_shape,
|
||||
Strides{1, 1},
|
||||
CoordinateDiff{0, 0},
|
||||
CoordinateDiff{0, 0},
|
||||
Strides{1, 1});
|
||||
auto gconv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
|
||||
pattern::has_class<opset1::GroupConvolutionBackpropData>());
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
auto gconv = std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData> (m.get_match_root());
|
||||
@ -182,12 +139,14 @@ void ngraph::pass::ConvertConvolutions::convert_group_convolution_backprop_data(
|
||||
auto conv_ie = std::make_shared<ngraph::op::DeconvolutionIE>(gconv->input_value(0),
|
||||
reshape,
|
||||
gconv->get_strides(),
|
||||
gconv->get_dilations(),
|
||||
gconv->get_pads_begin(),
|
||||
gconv->get_pads_end(),
|
||||
gconv->get_dilations(),
|
||||
gconv->output(0).get_shape(),
|
||||
group,
|
||||
gconv->get_auto_pad());
|
||||
gconv->get_auto_pad(),
|
||||
gconv->get_output_padding(),
|
||||
(gconv->inputs().size() == 3 ? gconv->input_value(2).get_node_shared_ptr()
|
||||
: nullptr));
|
||||
conv_ie->set_friendly_name(gconv->get_friendly_name());
|
||||
ngraph::copy_runtime_info(gconv, conv_ie);
|
||||
ngraph::replace_node(gconv, conv_ie);
|
||||
@ -196,7 +155,4 @@ void ngraph::pass::ConvertConvolutions::convert_group_convolution_backprop_data(
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ConvertGroupConvolutionBackpropData");
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
|
||||
auto m2 = std::make_shared<ngraph::pattern::Matcher>(gconv2, "ConvertGroupConvolutionBackpropData2");
|
||||
this->add_matcher(m2, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ using namespace testing;
|
||||
using InputShape = ngraph::PartialShape;
|
||||
using WeightsShape = ngraph::Shape;
|
||||
|
||||
class ConvertConvolutionsTest: public CommonTestUtils::TestsCommon,
|
||||
class ConvertConvolutionTest: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<InputShape, WeightsShape> > {
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
@ -66,7 +66,7 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConvertConvolutionsTest, CompareFunctions) {
|
||||
TEST_P(ConvertConvolutionTest, CompareFunctions) {
|
||||
const auto & orig_shape = f->get_output_partial_shape(0);
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertConvolutions().run_on_function(f);
|
||||
@ -76,7 +76,7 @@ TEST_P(ConvertConvolutionsTest, CompareFunctions) {
|
||||
ASSERT_TRUE(orig_shape.same_scheme(f->get_output_partial_shape(0))) << "Shape " << orig_shape << " is not equal to " << f->get_output_partial_shape(0);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ConvertConvolution, ConvertConvolutionsTest,
|
||||
INSTANTIATE_TEST_CASE_P(ConvertConvolution, ConvertConvolutionTest,
|
||||
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, WeightsShape{8, 3, 1, 2, 3}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, WeightsShape{8, 3, 1, 2, 3}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, WeightsShape{9, 3, 2, 3, 1}),
|
||||
|
@ -0,0 +1,96 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <map>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/algebraic_simplification.hpp>
|
||||
#include <ngraph/pass/visualize_tree.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_convolutions.hpp>
|
||||
#include <ngraph_ops/convolution_ie.hpp>
|
||||
#include <ngraph_ops/deconvolution_ie.hpp>
|
||||
|
||||
#include "ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
using InputShape = ngraph::PartialShape;
|
||||
using WeightsShape = ngraph::Shape;
|
||||
|
||||
class ConvertDeconvolutionTest: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<InputShape, WeightsShape> > {
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& input_shape = std::get<0>(GetParam());
|
||||
const auto& weights_shape = std::get<1>(GetParam());
|
||||
|
||||
f = get_initial_function(input_shape, weights_shape);
|
||||
f_ref = get_reference_function(input_shape, weights_shape);
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
|
||||
const ngraph::Shape & weights_shape) {
|
||||
auto spatial_dims = input_shape.rank().get_length() - 2;
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
|
||||
auto conv = std::make_shared<ngraph::opset1::ConvolutionBackpropData>(input, weights, ngraph::Strides(spatial_dims, 1),
|
||||
ngraph::CoordinateDiff(spatial_dims, 0), ngraph::CoordinateDiff(spatial_dims, 0), ngraph::Strides(spatial_dims, 1));
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{conv}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialShape & input_shape,
|
||||
const ngraph::Shape & weights_shape) {
|
||||
auto spatial_dims = input_shape.rank().get_length() - 2;
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
|
||||
auto conv = std::make_shared<ngraph::op::DeconvolutionIE>(input, weights, ngraph::Strides(spatial_dims, 1), ngraph::Strides(spatial_dims, 1),
|
||||
ngraph::CoordinateDiff(spatial_dims, 0), ngraph::CoordinateDiff(spatial_dims, 0));
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{conv}, ngraph::ParameterVector{input});
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConvertDeconvolutionTest, CompareFunctions) {
|
||||
const auto & orig_shape = f->get_output_partial_shape(0);
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertConvolutions().run_on_function(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
ASSERT_TRUE(orig_shape.same_scheme(f->get_output_partial_shape(0))) << "Shape " << orig_shape << " is not equal to " << f->get_output_partial_shape(0);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ConvertDeconvolution, ConvertDeconvolutionTest,
|
||||
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, WeightsShape{3, 8, 1, 2, 3}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, WeightsShape{3, 8, 1, 2, 3}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, WeightsShape{3, 9, 2, 3, 1}),
|
||||
std::make_tuple(InputShape{3, 3, DYN, 64, 64}, WeightsShape{3, 6, 3, 4, 2}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, WeightsShape{3, 5, 3, 4, 3}),
|
||||
std::make_tuple(InputShape{3, 3, 64, 64, DYN}, WeightsShape{3, 3, 3, 4, 3}),
|
||||
std::make_tuple(InputShape{1, 3, 64, 64}, WeightsShape{3, 6, 1, 1}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, WeightsShape{3, 7, 1, 1}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64}, WeightsShape{3, 8, 1, 2}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64}, WeightsShape{3, 9, 2, 3}),
|
||||
std::make_tuple(InputShape{3, 3, DYN, 64}, WeightsShape{3, 6, 3, 4}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN}, WeightsShape{3, 5, 3, 4}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN}, WeightsShape{3, 5, 1}),
|
||||
std::make_tuple(InputShape{DYN, 3, 10}, WeightsShape{3, 3, 1}),
|
||||
std::make_tuple(InputShape{2, DYN, 9}, WeightsShape{3, 2, 2}),
|
||||
std::make_tuple(InputShape{3, 3, DYN}, WeightsShape{3, 1, 3})));
|
Loading…
Reference in New Issue
Block a user