[GNA] Depth-wise separable convolution support (#7281)
* [GNA] Add support for DWSC, other fixes and code refactoring. * [GNA] Change supported layout to NHWC * [GNA] Detect bias const only on second position, move verification of dwsc to matcher
This commit is contained in:
parent
10f0075e90
commit
57b51701fa
@ -64,9 +64,10 @@
|
||||
#include "transformations/convert_matmul_to_pointwise_convolution.hpp"
|
||||
#include "transformations/split_convolution_with_large_buffer_size.hpp"
|
||||
#include "transformations/handle_transposes_around_matmul.hpp"
|
||||
#include "transformations/decompose_2d_conv.hpp"
|
||||
#include "transformations/convert_padded2valid_conv.hpp"
|
||||
#include "transformations/decompose_2d_convolution.hpp"
|
||||
#include "transformations/convert_padded_to_valid_convolution.hpp"
|
||||
#include "transformations/insert_reshape_around_matmul.hpp"
|
||||
#include "transformations/convert_dwsc_to_scaleshifts.hpp"
|
||||
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
|
||||
#include "transformations/remove_single_input_concat.hpp"
|
||||
|
||||
@ -716,7 +717,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
||||
manager.register_pass<ngraph::pass::ConvertPriorBox>();
|
||||
manager.register_pass<ngraph::pass::CommonOptimizations>();
|
||||
manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
|
||||
manager.register_pass<ConvertPadded2ValidConv>();
|
||||
manager.register_pass<ConvertDWSCToScaleShifts>();
|
||||
manager.register_pass<ConvertPaddedToValidConv>();
|
||||
if (config.gnaCompileTarget == InferenceEngine::GNAConfigParams::GNA_TARGET_2_0) {
|
||||
manager.register_pass<Decompose2DConvTransposedWithBiasAF>();
|
||||
manager.register_pass<Decompose2DConvTransposedWithBias>();
|
||||
@ -748,7 +750,6 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
||||
manager.register_pass<RemoveExtraReshapes>();
|
||||
// UnrollTI should be the last transformation in the transformation pipeline
|
||||
manager.register_pass<ngraph::pass::UnrollTensorIterator>();
|
||||
|
||||
const auto& pass_config = manager.get_pass_config();
|
||||
pass_config->disable<ngraph::pass::FakeQuantizeMulFusion>();
|
||||
pass_config->disable<ngraph::pass::FakeQuantizeReshapeFusion>();
|
||||
|
@ -0,0 +1,207 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <openvino/cc/ngraph/itt.hpp>
|
||||
|
||||
#include "transformations/convert_dwsc_to_scaleshifts.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ie_common.h>
|
||||
#include "utils/transformation_helper.hpp"
|
||||
|
||||
|
||||
using namespace GNAPluginNS;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ConvertDWSCToScaleShifts, "ConvertDWSCToScaleShifts", 0);
|
||||
|
||||
static std::shared_ptr<ngraph::Node> DecomposeDWSC(std::shared_ptr<ngraph::opset7::GroupConvolution> dwsc,
|
||||
std::shared_ptr<ngraph::opset7::Constant> bias_const, std::shared_ptr<ngraph::opset7::FakeQuantize> fq_bias,
|
||||
std::shared_ptr<ngraph::opset7::Reshape> flat_input_plane, std::shared_ptr<ngraph::Node> flat_filters_plane) {
|
||||
std::shared_ptr<ngraph::opset7::Constant> const_zero_padding;
|
||||
std::shared_ptr<ngraph::Node> reshaped_bias;
|
||||
ngraph::OutputVector output_chunks;
|
||||
auto input_channel_count = dwsc->get_input_shape(0)[1];
|
||||
auto input_width = dwsc->get_input_shape(0)[3];
|
||||
auto output_width = dwsc->get_output_shape(0)[3];
|
||||
auto filter_width = dwsc->get_input_shape(1)[4];
|
||||
auto pads_begin = dwsc->get_pads_begin()[1];
|
||||
auto stride_width = dwsc->get_strides()[1];
|
||||
auto dilation_width = dwsc->get_dilations()[1];
|
||||
|
||||
// Constant with zero padding
|
||||
if (pads_begin) {
|
||||
const_zero_padding = std::make_shared<ngraph::opset7::Constant>(dwsc->get_element_type(), ngraph::Shape{1, input_channel_count}, 0);
|
||||
copy_runtime_info(dwsc, const_zero_padding);
|
||||
}
|
||||
|
||||
// Reshape bias const
|
||||
if (bias_const) {
|
||||
auto bias_size = shape_size(bias_const->get_shape());
|
||||
reshaped_bias = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(bias_const,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, bias_size}), false);
|
||||
}
|
||||
|
||||
// Move filter over input performing multiplication and addition (scaleshift), take padding, stride, dilation and bias into account
|
||||
for (int32_t input_position = -pads_begin, o = 0; o < output_width; input_position += stride_width, o++) {
|
||||
std::shared_ptr<ngraph::Node> previous_layer_output, last_layer_output;
|
||||
int32_t filter_end = input_position + filter_width * dilation_width;
|
||||
bool first = true;
|
||||
|
||||
filter_end = filter_end < input_width ? filter_end : input_width;
|
||||
|
||||
for (int32_t filter_pos = input_position, filter_idx = 0; filter_pos < filter_end; filter_pos += dilation_width, filter_idx++) {
|
||||
if (filter_pos >= 0) {
|
||||
auto conv_input_slice = FlatCrop(flat_input_plane, filter_pos * input_channel_count, input_channel_count);
|
||||
auto conv_filter_slice = FlatCrop(flat_filters_plane, filter_idx * input_channel_count, input_channel_count);
|
||||
|
||||
if (first) {
|
||||
first = false;
|
||||
previous_layer_output = std::make_shared<ngraph::opset7::Multiply>(conv_input_slice, conv_filter_slice);
|
||||
copy_runtime_info(dwsc, previous_layer_output);
|
||||
if (bias_const) {
|
||||
previous_layer_output = std::make_shared<ngraph::opset7::Add>(previous_layer_output, reshaped_bias);
|
||||
copy_runtime_info(dwsc, previous_layer_output);
|
||||
previous_layer_output = InsertFQLayer(fq_bias, previous_layer_output);
|
||||
}
|
||||
last_layer_output = previous_layer_output;
|
||||
} else {
|
||||
last_layer_output = std::make_shared<ngraph::opset7::Multiply>(conv_input_slice, conv_filter_slice);
|
||||
copy_runtime_info(dwsc, last_layer_output);
|
||||
last_layer_output = std::make_shared<ngraph::opset7::Add>(last_layer_output, previous_layer_output);
|
||||
copy_runtime_info(dwsc, last_layer_output);
|
||||
previous_layer_output = last_layer_output;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!last_layer_output) {
|
||||
IE_ASSERT(const_zero_padding);
|
||||
last_layer_output = const_zero_padding;
|
||||
}
|
||||
|
||||
output_chunks.push_back(last_layer_output);
|
||||
}
|
||||
|
||||
// Concat is only needed when output width > 1
|
||||
if (output_chunks.size() > 1) {
|
||||
auto concat_output_plane = std::make_shared<ngraph::opset7::Concat>(output_chunks, 0);
|
||||
copy_runtime_info(dwsc, concat_output_plane);
|
||||
return concat_output_plane;
|
||||
}
|
||||
|
||||
return output_chunks[0].get_node_shared_ptr();
|
||||
}
|
||||
|
||||
static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
|
||||
std::shared_ptr<ngraph::Node> dwsc_node,
|
||||
std::shared_ptr<ngraph::Node> bias_const_node,
|
||||
std::shared_ptr<ngraph::Node> fq_bias_node,
|
||||
std::shared_ptr<ngraph::Node> trailing_transpose) {
|
||||
auto dwsc = std::dynamic_pointer_cast<ngraph::opset7::GroupConvolution>(dwsc_node);
|
||||
auto bias_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias_const_node);
|
||||
auto fq_bias = std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(fq_bias_node);
|
||||
|
||||
// We are looking for Transpose(NHWC->NCHW) => GroupConv => Transpose(NCHW->NHWC)
|
||||
// or similar cases, so required network must be in NHWC order like in TF
|
||||
if (!TransposeOrderMatches(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(leading_transpose), {0, 3, 1, 2}))
|
||||
return false;
|
||||
|
||||
if (!TransposeOrderMatches(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(trailing_transpose), {0, 2, 3, 1}))
|
||||
return false;
|
||||
|
||||
auto output_channel_count = dwsc->get_output_shape(0)[1];
|
||||
auto output_width = dwsc->get_output_shape(0)[3];
|
||||
|
||||
// Prepare flat input data
|
||||
auto flat_input_plane = std::make_shared<ngraph::opset7::Reshape>(leading_transpose->input_value(0),
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
|
||||
ngraph::Shape{1, shape_size(dwsc->input_value(0).get_shape())}), false);
|
||||
|
||||
// Prepare flat filter data
|
||||
auto filters_const = std::dynamic_pointer_cast<ngraph::Node>(dwsc->get_input_node_shared_ptr(1));
|
||||
auto filters_size = shape_size(filters_const->get_shape());
|
||||
|
||||
auto transposed_filters_const = ngraph::op::util::make_try_fold<ngraph::opset7::Transpose>(filters_const,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{5}, ngraph::Shape{4, 1, 2, 3, 0}));
|
||||
|
||||
auto flat_filters_plane = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(transposed_filters_const,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, filters_size}), false);
|
||||
|
||||
copy_runtime_info(dwsc, {flat_input_plane, transposed_filters_const, flat_filters_plane});
|
||||
|
||||
// Convert DWSC to a set of diagonal layers
|
||||
auto output_plane = DecomposeDWSC(dwsc, bias_const, fq_bias, flat_input_plane, flat_filters_plane);
|
||||
|
||||
// Restore the original output shape
|
||||
auto result = std::make_shared<ngraph::opset7::Reshape>(output_plane,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
|
||||
ngraph::Shape{1, output_channel_count, 1, output_width}), false);
|
||||
copy_runtime_info(dwsc, result);
|
||||
|
||||
// We need to put here the original Group Convolution layer name, so the new layer output can be used as a network result
|
||||
std::string result_name = trailing_transpose->get_friendly_name();
|
||||
replace_node(trailing_transpose, result);
|
||||
result->set_friendly_name(result_name);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool VerifyDWSC(const ngraph::Output<ngraph::Node>& output) {
|
||||
auto dwsc = output.get_node();
|
||||
|
||||
// Verify it's a 1D convolution
|
||||
// Verify that filter group count == input channel count
|
||||
// Verify that per group filter output channel count == 1
|
||||
if (!consumers_and_rank(1, 4)(output) ||
|
||||
dwsc->get_input_shape(1)[3] != 1 || dwsc->get_input_shape(0)[2] != 1 || dwsc->get_output_shape(0)[2] != 1 ||
|
||||
dwsc->get_input_shape(1)[0] != dwsc->get_input_shape(0)[1] ||
|
||||
dwsc->get_input_shape(1)[1] != 1)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
ConvertDWSCToScaleShifts::ConvertDWSCToScaleShifts() {
|
||||
MATCHER_SCOPE(ConvertDWSCToScaleShifts);
|
||||
|
||||
auto const_input = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
||||
auto leading_transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({ngraph::pattern::any_input(), const_input},
|
||||
consumers_and_rank(1, 4));
|
||||
auto filters_const_fq = ngraph::pattern::wrap_type<ngraph::opset7::Constant>(ngraph::pattern::rank_equals(4));
|
||||
auto fq_filters_const = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({filters_const_fq, const_input, const_input, const_input, const_input},
|
||||
consumers_and_rank(1, 4));
|
||||
auto reshape_filters_const = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({fq_filters_const, const_input}, ngraph::pattern::rank_equals(5));
|
||||
auto filters_const = ngraph::pattern::wrap_type<ngraph::opset7::Constant>(ngraph::pattern::rank_equals(5));
|
||||
auto dwsc_filters = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{filters_const, reshape_filters_const });
|
||||
auto dwsc = ngraph::pattern::wrap_type<ngraph::opset7::GroupConvolution>({leading_transpose, dwsc_filters}, VerifyDWSC);
|
||||
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Add>({dwsc, const_input});
|
||||
auto fq_bias = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({bias, const_input, const_input, const_input, const_input},
|
||||
consumers_and_rank(1, 4));
|
||||
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{dwsc, bias, fq_bias});
|
||||
auto trailing_transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({transpose_input, const_input}, consumers_and_rank(1, 4));
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto bias_it = pattern_map.find(bias);
|
||||
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
|
||||
std::shared_ptr<ngraph::Node> bias_const = nullptr;
|
||||
|
||||
if (bias_node && (bias_const = VerifyBiasGetConst(pattern_map.at(dwsc).get_node_shared_ptr(), bias_node)) == nullptr)
|
||||
return false;
|
||||
|
||||
auto fq_bias_it = pattern_map.find(fq_bias);
|
||||
auto fq_bias_node = (fq_bias_it == std::end(pattern_map) ? nullptr : fq_bias_it->second.get_node_shared_ptr());
|
||||
|
||||
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), pattern_map.at(dwsc).get_node_shared_ptr(),
|
||||
bias_const, fq_bias_node,
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr());
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(trailing_transpose, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace GNAPluginNS {
|
||||
|
||||
/**
|
||||
* @brief Convert a depthwise separable convolution (represented by a GroupConvolution) to a set of ScaleShift layers (MatMul + Add)
|
||||
* Additionally supported are bias and fake quantize layers.
|
||||
*/
|
||||
class ConvertDWSCToScaleShifts : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertDWSCToScaleShifts();
|
||||
};
|
||||
|
||||
} // namespace GNAPluginNS
|
@ -4,7 +4,7 @@
|
||||
|
||||
#include <openvino/cc/ngraph/itt.hpp>
|
||||
|
||||
#include "transformations/convert_padded2valid_conv.hpp"
|
||||
#include "transformations/convert_padded_to_valid_convolution.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
|
||||
using namespace GNAPluginNS;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ConvertPadded2ValidConv, "ConvertPadded2ValidConv", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ConvertPaddedToValidConv, "ConvertPaddedToValidConv", 0);
|
||||
|
||||
static bool VerifyAndGetConvData(std::shared_ptr<ngraph::opset7::Convolution> conv, ConvData& conv_data) {
|
||||
const auto& input = conv->input_value(0);
|
||||
@ -34,17 +34,6 @@ static bool VerifyAndGetConvData(std::shared_ptr<ngraph::opset7::Convolution> co
|
||||
return conv_data.pads_begin_height || conv_data.pads_end_height || conv_data.pads_begin_width || conv_data.pads_end_width;
|
||||
}
|
||||
|
||||
static bool VerifyBias(std::shared_ptr<ngraph::opset7::Add> bias, const size_t& filter_count) {
|
||||
auto add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(0).get_node_shared_ptr());
|
||||
|
||||
// We need to check both inputs of Add when looking for constant
|
||||
if (!add_const)
|
||||
add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(1).get_node_shared_ptr());
|
||||
|
||||
// The add may be a normal add not convolution bias, then we just go further
|
||||
return (add_const && shape_size(add_const->get_shape()) == filter_count);
|
||||
}
|
||||
|
||||
static void InsertPadding(ngraph::OutputVector& input_rows_to_concat, size_t size, const std::shared_ptr<ngraph::opset7::Convolution>& conv,
|
||||
const std::shared_ptr<ngraph::opset7::Constant> padding_const, size_t biggest_padding) {
|
||||
|
||||
@ -181,9 +170,6 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
|
||||
if (!TransposeOrderMatches(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(trailing_transpose), {0, 2, 3, 1}))
|
||||
return false;
|
||||
|
||||
if (bias && !VerifyBias(std::dynamic_pointer_cast<ngraph::opset7::Add>(bias), conv_data.filter_count))
|
||||
return false;
|
||||
|
||||
GeneratePadding(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(leading_transpose),
|
||||
std::dynamic_pointer_cast<ngraph::opset7::Convolution>(conv), conv_data);
|
||||
|
||||
@ -196,8 +182,8 @@ static std::function<bool(ngraph::Output<ngraph::Node>)> consumers_and_rank(cons
|
||||
};
|
||||
}
|
||||
|
||||
ConvertPadded2ValidConv::ConvertPadded2ValidConv() {
|
||||
MATCHER_SCOPE(ConvertPadded2ValidConv);
|
||||
ConvertPaddedToValidConv::ConvertPaddedToValidConv() {
|
||||
MATCHER_SCOPE(ConvertPaddedToValidConv);
|
||||
|
||||
auto const_input = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
||||
auto leading_transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({ngraph::pattern::any_input(), const_input},
|
||||
@ -237,6 +223,9 @@ ConvertPadded2ValidConv::ConvertPadded2ValidConv() {
|
||||
auto bias_it = pattern_map.find(bias);
|
||||
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
|
||||
|
||||
if (bias_node && !VerifyBiasGetConst(pattern_map.at(conv).get_node_shared_ptr(), bias_node))
|
||||
return false;
|
||||
|
||||
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), pattern_map.at(conv).get_node_shared_ptr(),
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr(), bias_node);
|
||||
};
|
@ -28,10 +28,10 @@ namespace GNAPluginNS {
|
||||
* Transpose (NCHW -> NHWC) Transpose (NCHW -> NHWC)
|
||||
*
|
||||
*/
|
||||
class ConvertPadded2ValidConv : public ngraph::pass::MatcherPass {
|
||||
class ConvertPaddedToValidConv : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertPadded2ValidConv();
|
||||
ConvertPaddedToValidConv();
|
||||
};
|
||||
|
||||
} // namespace GNAPluginNS
|
@ -4,9 +4,7 @@
|
||||
|
||||
#include <openvino/cc/ngraph/itt.hpp>
|
||||
|
||||
#include "transformations/decompose_2d_conv.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include "transformations/decompose_2d_convolution.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
@ -68,22 +66,6 @@ static bool VerifyAndGetConvData(std::shared_ptr<ngraph::opset7::Convolution> co
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::shared_ptr<ngraph::Node> VerifyBiasAndReshapeConst(std::shared_ptr<ngraph::opset7::Add> conv_bias, const ConvData& conv_data) {
|
||||
auto add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(conv_bias->input_value(1).get_node_shared_ptr());
|
||||
|
||||
if (add_const) {
|
||||
auto bias_size = shape_size(add_const->get_shape());
|
||||
|
||||
// The add may be a normal add not conv bias, then we just go further
|
||||
if (bias_size == conv_data.filter_count) {
|
||||
return ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(add_const,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{1, bias_size, 1, 1}), false);
|
||||
}
|
||||
}
|
||||
// Bias size does not match (or dynamic bias), can't decompose such convolution
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool VerifyMaxPool(GraphData& graph_data, std::shared_ptr<ngraph::opset7::MaxPool> max_pool) {
|
||||
auto pool_filter = max_pool->get_kernel();
|
||||
auto pool_strides = max_pool->get_strides();
|
||||
@ -236,7 +218,7 @@ static void TransformInput(const GraphData& graph_data, const ConvData& conv_dat
|
||||
*/
|
||||
|
||||
// First we need to prepare flat (height = 1) slices of input data proper for flattened (height = 1) filters created later on;
|
||||
// the input datat is overlapping (duplicated)
|
||||
// the input data is overlapping (duplicated)
|
||||
ngraph::OutputVector dilated_input_planes;
|
||||
for (size_t filter_height = 0; filter_height < conv_data.filter_height; filter_height++) {
|
||||
size_t offset;
|
||||
@ -280,16 +262,6 @@ static void TransformInput(const GraphData& graph_data, const ConvData& conv_dat
|
||||
split_input_plane = flattened_dilated_transposed_input;
|
||||
}
|
||||
|
||||
static void InsertFQLayer(const std::shared_ptr<ngraph::opset7::FakeQuantize> fqLayer,
|
||||
std::shared_ptr<ngraph::Node> lastNode) {
|
||||
if (fqLayer != nullptr) {
|
||||
lastNode = fqLayer->clone_with_new_inputs({lastNode,
|
||||
fqLayer->input_value(1), fqLayer->input_value(2),
|
||||
fqLayer->input_value(3), fqLayer->input_value(4)});
|
||||
ngraph::copy_runtime_info(fqLayer, lastNode);
|
||||
}
|
||||
}
|
||||
|
||||
// Valid 1D (decomposed 2D) convolution wrapped with transposes NHWC => NCHW => conv => NCHW => NHWC
|
||||
static std::shared_ptr<ngraph::Node> Create1DConv(const GraphData& graph_data, const ConvData& conv_data, const ngraph::Output<ngraph::Node>& input,
|
||||
std::shared_ptr<ngraph::Node> filters, const size_t conv_index, const size_t h_index) {
|
||||
@ -298,7 +270,7 @@ static std::shared_ptr<ngraph::Node> Create1DConv(const GraphData& graph_data, c
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 3, 1, 2})->output(0));
|
||||
|
||||
// Fake quantize
|
||||
InsertFQLayer(graph_data.fq_conv, filters);
|
||||
filters = InsertFQLayer(graph_data.fq_conv, filters);
|
||||
|
||||
// 1D Convolution
|
||||
auto conv = std::make_shared<ngraph::opset7::Convolution>(nchw_input, filters,
|
||||
@ -306,13 +278,16 @@ static std::shared_ptr<ngraph::Node> Create1DConv(const GraphData& graph_data, c
|
||||
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
|
||||
std::string conv_name = graph_data.conv->get_friendly_name() + "_H_" + std::to_string(h_index) + "_CH_" + std::to_string(0);
|
||||
conv->set_friendly_name(conv_name);
|
||||
std::shared_ptr<ngraph::Node> last_conv_block_op = conv;
|
||||
|
||||
// Bias & fake quantize
|
||||
std::shared_ptr<ngraph::Node> last_conv_block_op = conv;
|
||||
if (graph_data.bias_const && conv_index == 0) {
|
||||
last_conv_block_op = std::make_shared<ngraph::opset7::Add>(conv, graph_data.bias_const);
|
||||
auto bias_size = shape_size(graph_data.bias_const->get_shape());
|
||||
auto reshaped_bias_const = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(graph_data.bias_const,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{1, bias_size, 1, 1}), false);
|
||||
last_conv_block_op = std::make_shared<ngraph::opset7::Add>(conv, reshaped_bias_const);
|
||||
copy_runtime_info(graph_data.conv, last_conv_block_op);
|
||||
InsertFQLayer(graph_data.fq_bias, last_conv_block_op);
|
||||
last_conv_block_op = InsertFQLayer(graph_data.fq_bias, last_conv_block_op);
|
||||
}
|
||||
|
||||
// Max pooling
|
||||
@ -326,7 +301,7 @@ static std::shared_ptr<ngraph::Node> Create1DConv(const GraphData& graph_data, c
|
||||
if (graph_data.af && graph_data.conv_count == 1) {
|
||||
last_conv_block_op = graph_data.af->copy_with_new_inputs({last_conv_block_op});
|
||||
copy_runtime_info(conv, last_conv_block_op);
|
||||
InsertFQLayer(graph_data.fq_af, last_conv_block_op);
|
||||
last_conv_block_op = InsertFQLayer(graph_data.fq_af, last_conv_block_op);
|
||||
}
|
||||
|
||||
// Transpose NCHW => NHWC
|
||||
@ -472,6 +447,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
|
||||
std::shared_ptr<ngraph::Node> conv,
|
||||
std::shared_ptr<ngraph::Node> trailing_transpose,
|
||||
std::shared_ptr<ngraph::Node> bias,
|
||||
std::shared_ptr<ngraph::Node> bias_const,
|
||||
std::shared_ptr<ngraph::Node> fq_bias,
|
||||
std::shared_ptr<ngraph::Node> max_pool,
|
||||
std::shared_ptr<ngraph::Node> af,
|
||||
@ -486,7 +462,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
|
||||
std::dynamic_pointer_cast<ngraph::opset7::MaxPool>(max_pool),
|
||||
std::dynamic_pointer_cast<ngraph::op::util::UnaryElementwiseArithmetic>(af),
|
||||
std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(fq_af),
|
||||
last_op_for_replacement, nullptr, 1, 1, 1};
|
||||
last_op_for_replacement, bias_const, 1, 1, 1};
|
||||
ConvData conv_data;
|
||||
|
||||
if (!VerifyAndGetConvData(std::dynamic_pointer_cast<ngraph::opset7::Convolution>(conv), conv_data))
|
||||
@ -500,9 +476,6 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
|
||||
if (!TransposeOrderMatches(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(trailing_transpose), {0, 2, 3, 1}))
|
||||
return false;
|
||||
|
||||
if (bias && !(graph_data.bias_const = VerifyBiasAndReshapeConst(std::dynamic_pointer_cast<ngraph::opset7::Add>(bias), conv_data)))
|
||||
return false;
|
||||
|
||||
if (max_pool && !VerifyMaxPool(graph_data, std::dynamic_pointer_cast<ngraph::opset7::MaxPool>(max_pool)))
|
||||
return false;
|
||||
|
||||
@ -515,22 +488,6 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool VerifyBias(std::shared_ptr<ngraph::Node> conv, std::shared_ptr<ngraph::Node> bias) {
|
||||
auto add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(1).get_node_shared_ptr());
|
||||
|
||||
if (!add_const) {
|
||||
add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(0).get_node_shared_ptr());
|
||||
}
|
||||
|
||||
if (!add_const) {
|
||||
auto bias_size = shape_size(add_const->get_shape());
|
||||
auto conv_filter_count = conv->input_value(1).get_shape()[0];
|
||||
if (bias_size == conv_filter_count)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Decompose2DConv::Decompose2DConv() {
|
||||
MATCHER_SCOPE(Decompose2DConv);
|
||||
|
||||
@ -576,6 +533,11 @@ Decompose2DConv::Decompose2DConv() {
|
||||
auto fq_conv_node = (fq_conv_it == std::end(pattern_map) ? nullptr : fq_conv_it->second.get_node_shared_ptr());
|
||||
auto bias_it = pattern_map.find(bias);
|
||||
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
|
||||
std::shared_ptr<ngraph::Node> bias_const_node = nullptr;
|
||||
|
||||
if (bias_node && !(bias_const_node = VerifyBiasGetConst(pattern_map.at(conv).get_node_shared_ptr(), bias_node)))
|
||||
return false;
|
||||
|
||||
auto fq_bias_it = pattern_map.find(fq_bias);
|
||||
auto fq_bias_node = (fq_bias_it == std::end(pattern_map) ? nullptr : fq_bias_it->second.get_node_shared_ptr());
|
||||
auto fq_af_it = pattern_map.find(fq_af);
|
||||
@ -596,7 +558,7 @@ Decompose2DConv::Decompose2DConv() {
|
||||
}
|
||||
|
||||
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), fq_conv_node, pattern_map.at(conv).get_node_shared_ptr(),
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr(), bias_node, fq_bias_node, max_pool_node, af_node, fq_af_node,
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr(), bias_node, bias_const_node, fq_bias_node, max_pool_node, af_node, fq_af_node,
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr());
|
||||
};
|
||||
|
||||
@ -621,11 +583,13 @@ Decompose2DConvTransposedWithBias::Decompose2DConvTransposedWithBias() {
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
if (!VerifyBias(pattern_map.at(conv).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr()))
|
||||
std::shared_ptr<ngraph::Node> bias_const_node = nullptr;
|
||||
|
||||
if (!(bias_const_node = VerifyBiasGetConst(pattern_map.at(conv).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr())))
|
||||
return false;
|
||||
|
||||
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), nullptr, pattern_map.at(conv).get_node_shared_ptr(),
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr(), nullptr, nullptr,
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr(), bias_const_node, nullptr, nullptr,
|
||||
nullptr, nullptr, pattern_map.at(bias).get_node_shared_ptr());
|
||||
};
|
||||
|
||||
@ -654,11 +618,13 @@ Decompose2DConvTransposedWithBiasAF::Decompose2DConvTransposedWithBiasAF() {
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
if (!VerifyBias(pattern_map.at(conv).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr()))
|
||||
std::shared_ptr<ngraph::Node> bias_const_node = nullptr;
|
||||
|
||||
if (!(bias_const_node = VerifyBiasGetConst(pattern_map.at(conv).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr())))
|
||||
return false;
|
||||
|
||||
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), nullptr, pattern_map.at(conv).get_node_shared_ptr(),
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr(), nullptr,
|
||||
pattern_map.at(trailing_transpose).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr(), bias_const_node, nullptr,
|
||||
nullptr, pattern_map.at(af).get_node_shared_ptr(), nullptr, pattern_map.at(af).get_node_shared_ptr());
|
||||
};
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include "transformation_helper.hpp"
|
||||
|
||||
|
||||
@ -72,4 +73,29 @@ std::shared_ptr<ngraph::opset7::StridedSlice> FlatCrop(ngraph::Output<ngraph::No
|
||||
std::vector<int64_t>{1, 0}); // end mask
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> VerifyBiasGetConst(std::shared_ptr<ngraph::Node> conv, std::shared_ptr<ngraph::Node> bias) {
|
||||
auto add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(1).get_node_shared_ptr());
|
||||
|
||||
// Check if it's really a bias and not just addition
|
||||
if (add_const) {
|
||||
auto bias_size = shape_size(add_const->get_shape());
|
||||
auto conv_filter_count = conv->get_output_shape(0)[1];
|
||||
if (bias_size == conv_filter_count)
|
||||
return add_const;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> InsertFQLayer(const std::shared_ptr<ngraph::opset7::FakeQuantize> fq_layer,
|
||||
std::shared_ptr<ngraph::Node> last_node) {
|
||||
if (fq_layer != nullptr) {
|
||||
auto new_fq = fq_layer->clone_with_new_inputs({last_node,
|
||||
fq_layer->input_value(1), fq_layer->input_value(2),
|
||||
fq_layer->input_value(3), fq_layer->input_value(4)});
|
||||
ngraph::copy_runtime_info(new_fq, fq_layer);
|
||||
return new_fq;
|
||||
}
|
||||
return last_node;
|
||||
}
|
||||
|
||||
} // namespace GNAPluginNS
|
||||
|
@ -61,4 +61,21 @@ bool TransposeOrderMatches(std::shared_ptr<ngraph::opset7::Transpose> transpose,
|
||||
* @return pointer to the newly created slice
|
||||
*/
|
||||
std::shared_ptr<ngraph::opset7::StridedSlice> FlatCrop(ngraph::Output<ngraph::Node> input, size_t offset, size_t size);
|
||||
|
||||
/**
|
||||
* @brief checks whether an add present after convolution is a bias and gets its const input
|
||||
* @param conv convolution layer preceding potential bias
|
||||
* @param bias potential bias layer passed from ngraph matcher
|
||||
* @return bias const if the add layer present after convolution is a bias, nullptr otherwise
|
||||
*/
|
||||
std::shared_ptr<ngraph::Node> VerifyBiasGetConst(std::shared_ptr<ngraph::Node> conv, std::shared_ptr<ngraph::Node> bias);
|
||||
|
||||
/**
|
||||
* @brief inserts a new fake quantize layer (if it exists) copied from an existing fake quantize layer and conncts it to the output of a given layer
|
||||
* @param fq_layer existing fake quantize layer to be copied
|
||||
* @param last_node the node to which output the new fake quantize layer will be connected
|
||||
* @return new fake quantize layer or the last node
|
||||
*/
|
||||
std::shared_ptr<ngraph::Node> InsertFQLayer(const std::shared_ptr<ngraph::opset7::FakeQuantize> fq_layer, std::shared_ptr<ngraph::Node> last_node);
|
||||
|
||||
} // namespace GNAPluginNS
|
||||
|
@ -0,0 +1,215 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <map>
|
||||
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::opset7;
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
enum class modelType {
|
||||
TranspDWSCTransp = 0, /* Transpose(NHWC->NCHW) => DWSC (Group Convolution) => Transpose(NCHW->NHWC) */
|
||||
TranspDWSCBiasTransp, /* Transpose(NHWC->NCHW) => DWSC => Broadcasted Add (Bias) => Transpose(NCHW->NHWC) */
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
InferenceEngine::SizeVector, // Kernel size
|
||||
InferenceEngine::SizeVector, // Strides
|
||||
std::vector<ptrdiff_t>, // Pad begin
|
||||
std::vector<ptrdiff_t>, // Pad end
|
||||
InferenceEngine::SizeVector, // Dilation
|
||||
op::PadType, // Padding type
|
||||
size_t, // Num out channels
|
||||
size_t, // Num groups
|
||||
InferenceEngine::SizeVector // Bias
|
||||
> DWSCParams;
|
||||
|
||||
typedef std::tuple<
|
||||
DWSCParams, // DWSC and bias parameters
|
||||
InferenceEngine::Precision, // Network Precision
|
||||
std::string, // Target Device
|
||||
std::map<std::string, std::string>, // Configuration
|
||||
InferenceEngine::SizeVector, // Input shapes
|
||||
modelType // Test model
|
||||
> DWSCToScaleShiftsParams;
|
||||
|
||||
class DWSCToScaleShiftsTest : public testing::WithParamInterface<DWSCToScaleShiftsParams>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<DWSCToScaleShiftsParams> obj) {
|
||||
DWSCParams params;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetDevice;
|
||||
std::map<std::string, std::string> configuration;
|
||||
InferenceEngine::SizeVector inputShape;
|
||||
modelType model;
|
||||
std::tie(params, netPrecision, targetDevice, configuration, inputShape, model) = obj.param;
|
||||
op::PadType padType;
|
||||
InferenceEngine::SizeVector filter, stride, dilation, bias;
|
||||
std::vector<ptrdiff_t> padBegin, padEnd;
|
||||
size_t numOutChannels, numGroups;
|
||||
std::tie(filter, stride, padBegin, padEnd, dilation, padType, numOutChannels, numGroups, bias) = params;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "M=" << static_cast<uint32_t>(model) << "_";
|
||||
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
|
||||
result << "K" << CommonTestUtils::vec2str(filter) << "_";
|
||||
result << "S" << CommonTestUtils::vec2str(stride) << "_";
|
||||
result << "PB" << CommonTestUtils::vec2str(padBegin) << "_";
|
||||
result << "PE" << CommonTestUtils::vec2str(padEnd) << "_";
|
||||
result << "D=" << CommonTestUtils::vec2str(dilation) << "_";
|
||||
result << "O=" << numOutChannels << "_";
|
||||
result << "AP=" << padType << "_";
|
||||
result << "B=" << CommonTestUtils::vec2str(bias) << "_";
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "targetDevice=" << targetDevice << "_";
|
||||
for (auto const& configItem : configuration) {
|
||||
result << "_configItem=" << configItem.first << "_" << configItem.second;
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
threshold = 0.05f;
|
||||
DWSCParams params;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::vector<size_t> inputShape;
|
||||
modelType model;
|
||||
std::tie(params, netPrecision, targetDevice, configuration, inputShape, model) = this->GetParam();
|
||||
op::PadType padType;
|
||||
InferenceEngine::SizeVector filter, stride, dilation, bias;
|
||||
std::vector<ptrdiff_t> padBegin, padEnd;
|
||||
size_t numOutChannels, numGroups;
|
||||
std::tie(filter, stride, padBegin, padEnd, dilation, padType, numOutChannels, numGroups, bias) = params;
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
|
||||
auto input = builder::makeParams(ngPrc, {inputShape});
|
||||
auto transposeInOrder = op::Constant::create(element::i64, Shape{4}, {0, 3, 1, 2});
|
||||
auto transposeIn = std::make_shared<Transpose>(input[0], transposeInOrder);
|
||||
auto filterSize = std::accumulate(std::begin(filter), std::end(filter), 1ull, std::multiplies<size_t>());
|
||||
auto filterWeights = CommonTestUtils::generate_float_numbers(numOutChannels * (inputShape[3] / numGroups) * filterSize, -0.5f, 0.5f);
|
||||
auto dwsc = builder::makeGroupConvolution(transposeIn, ngPrc, filter, stride, padBegin,
|
||||
padEnd, dilation, padType, numOutChannels, numGroups, false, filterWeights);
|
||||
auto transposeOutOrder = op::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1});
|
||||
auto lastOp = std::make_shared<Transpose>(dwsc, transposeOutOrder);
|
||||
|
||||
if (model == modelType::TranspDWSCBiasTransp) {
|
||||
Shape biasShape{bias};
|
||||
auto biasWeights = CommonTestUtils::generate_float_numbers(shape_size(biasShape), -1.0f, 1.0f);
|
||||
auto biasConst = std::make_shared<Constant>(ngPrc, biasShape, biasWeights);
|
||||
auto bias = std::make_shared<Add>(dwsc, biasConst);
|
||||
lastOp = std::make_shared<Transpose>(bias, transposeOutOrder);
|
||||
}
|
||||
|
||||
auto result = std::make_shared<Result>(lastOp);
|
||||
function = std::make_shared<Function>(ResultVector{result}, ParameterVector{input});
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(DWSCToScaleShiftsTest, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
};
|
||||
|
||||
const std::vector<std::map<std::string, std::string>> configs = {
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_0", "1"},
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<op::PadType> padTypes = {
|
||||
op::PadType::VALID,
|
||||
op::PadType::EXPLICIT,
|
||||
op::PadType::SAME_LOWER,
|
||||
op::PadType::SAME_UPPER
|
||||
};
|
||||
|
||||
const std::vector<modelType> models = {
|
||||
modelType::TranspDWSCTransp,
|
||||
modelType::TranspDWSCBiasTransp
|
||||
};
|
||||
|
||||
const std::vector<std::vector<size_t>> inputNHWC = {{1, 1, 5, 32}};
|
||||
const std::vector<std::vector<size_t >> filters = {{1, 3}};
|
||||
const std::vector<std::vector<size_t >> strides = {{1, 1}, {1, 2}};
|
||||
const std::vector<std::vector<ptrdiff_t>> padBegins = {{0, 1}, {0, 2}};
|
||||
const std::vector<std::vector<ptrdiff_t>> padEnds = {{0, 1}};
|
||||
const std::vector<std::vector<size_t >> dilations = {{1, 1}};
|
||||
const std::vector<size_t> numOutChannels = {32};
|
||||
const std::vector<size_t> numGroups = {32};
|
||||
const std::vector<std::vector<size_t >> biases = {{1, 32, 1, 1}};
|
||||
|
||||
const auto convParams = ::testing::Combine(
|
||||
::testing::ValuesIn(filters),
|
||||
::testing::ValuesIn(strides),
|
||||
::testing::ValuesIn(padBegins),
|
||||
::testing::ValuesIn(padEnds),
|
||||
::testing::ValuesIn(dilations),
|
||||
::testing::ValuesIn(padTypes),
|
||||
::testing::ValuesIn(numOutChannels),
|
||||
::testing::ValuesIn(numGroups),
|
||||
::testing::ValuesIn(biases)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_DWSCToScaleShifts, DWSCToScaleShiftsTest,
|
||||
::testing::Combine(
|
||||
convParams,
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(configs),
|
||||
::testing::ValuesIn(inputNHWC),
|
||||
::testing::ValuesIn(models)),
|
||||
DWSCToScaleShiftsTest::getTestCaseName);
|
||||
|
||||
/* ============= Strides & Dilations Combination ============= */
|
||||
|
||||
const std::vector<op::PadType> padTypesSD = {
|
||||
op::PadType::VALID,
|
||||
};
|
||||
|
||||
const std::vector<std::vector<size_t>> inputNHWCSD = {{1, 1, 8, 32}};
|
||||
const std::vector<std::vector<size_t >> dilationsSD = {{1, 1}, {1, 2}};
|
||||
|
||||
const auto convParamsSD = ::testing::Combine(
|
||||
::testing::ValuesIn(filters),
|
||||
::testing::ValuesIn(strides),
|
||||
::testing::ValuesIn(padBegins),
|
||||
::testing::ValuesIn(padEnds),
|
||||
::testing::ValuesIn(dilationsSD),
|
||||
::testing::ValuesIn(padTypesSD),
|
||||
::testing::ValuesIn(numOutChannels),
|
||||
::testing::ValuesIn(numGroups),
|
||||
::testing::ValuesIn(biases)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_DWSCToScaleShiftsStridesDilations, DWSCToScaleShiftsTest,
|
||||
::testing::Combine(
|
||||
convParamsSD,
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(configs),
|
||||
::testing::ValuesIn(inputNHWCSD),
|
||||
::testing::ValuesIn(models)),
|
||||
DWSCToScaleShiftsTest::getTestCaseName);
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -57,12 +57,12 @@ typedef std::tuple<
|
||||
std::map<std::string, std::string>, // Configuration
|
||||
InferenceEngine::SizeVector, // Input shapes
|
||||
modelType // Test model
|
||||
> padded2ValidParams;
|
||||
> paddedToValidParams;
|
||||
|
||||
class Padded2ValidConvTest : public testing::WithParamInterface<padded2ValidParams>,
|
||||
class PaddedToValidConvTest : public testing::WithParamInterface<paddedToValidParams>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<padded2ValidParams> obj) {
|
||||
static std::string getTestCaseName(testing::TestParamInfo<paddedToValidParams> obj) {
|
||||
convSpecificParams convParams;
|
||||
miscSpecificParams miscParams;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
@ -195,26 +195,26 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
class Gna30Padded2ValidConvTest : public Padded2ValidConvTest, GnaLayerTestCheck {
|
||||
class Gna30PaddedToValidConvTest : public PaddedToValidConvTest, GnaLayerTestCheck {
|
||||
protected:
|
||||
void Run() override {
|
||||
GnaLayerTestCheck::SkipTestCheck();
|
||||
|
||||
if (!GnaLayerTestCheck::skipTest) {
|
||||
Padded2ValidConvTest::Run();
|
||||
PaddedToValidConvTest::Run();
|
||||
}
|
||||
}
|
||||
|
||||
void SetUp() override {
|
||||
Padded2ValidConvTest::SetUp();
|
||||
PaddedToValidConvTest::SetUp();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(Padded2ValidConvTest, CompareWithRefs) {
|
||||
TEST_P(PaddedToValidConvTest, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
TEST_P(Gna30Padded2ValidConvTest, CompareWithRefs) {
|
||||
TEST_P(Gna30PaddedToValidConvTest, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
@ -322,7 +322,7 @@ const auto misc2DParams = ::testing::Combine(
|
||||
::testing::ValuesIn(maxpool2DStrides)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_1DPadded2Valid, Padded2ValidConvTest,
|
||||
INSTANTIATE_TEST_CASE_P(smoke_1DPaddedToValid, PaddedToValidConvTest,
|
||||
::testing::Combine(
|
||||
conv1DParams,
|
||||
misc1DParams,
|
||||
@ -331,9 +331,9 @@ INSTANTIATE_TEST_CASE_P(smoke_1DPadded2Valid, Padded2ValidConvTest,
|
||||
::testing::ValuesIn(configs1D),
|
||||
::testing::ValuesIn(input1DNHWC),
|
||||
::testing::ValuesIn(models)),
|
||||
Padded2ValidConvTest::getTestCaseName);
|
||||
PaddedToValidConvTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_1DPadded2Valid, Gna30Padded2ValidConvTest,
|
||||
INSTANTIATE_TEST_CASE_P(smoke_1DPaddedToValid, Gna30PaddedToValidConvTest,
|
||||
::testing::Combine(
|
||||
conv1DParams,
|
||||
misc1DParams,
|
||||
@ -342,9 +342,9 @@ INSTANTIATE_TEST_CASE_P(smoke_1DPadded2Valid, Gna30Padded2ValidConvTest,
|
||||
::testing::ValuesIn(configs1D_Gna30),
|
||||
::testing::ValuesIn(input1DNHWC),
|
||||
::testing::ValuesIn(models)),
|
||||
Gna30Padded2ValidConvTest::getTestCaseName);
|
||||
Gna30PaddedToValidConvTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_2DPadded2Valid, Gna30Padded2ValidConvTest,
|
||||
INSTANTIATE_TEST_CASE_P(smoke_2DPaddedToValid, Gna30PaddedToValidConvTest,
|
||||
::testing::Combine(
|
||||
conv2DParams,
|
||||
misc2DParams,
|
||||
@ -353,6 +353,6 @@ INSTANTIATE_TEST_CASE_P(smoke_2DPadded2Valid, Gna30Padded2ValidConvTest,
|
||||
::testing::ValuesIn(configs2D),
|
||||
::testing::ValuesIn(input2DNHWC),
|
||||
::testing::ValuesIn(models)),
|
||||
Gna30Padded2ValidConvTest::getTestCaseName);
|
||||
Gna30PaddedToValidConvTest::getTestCaseName);
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,384 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "transformations/convert_dwsc_to_scaleshifts.hpp"
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
namespace testing {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class modelType {
|
||||
TranspDWSCTransp = 0, /* Transpose(NHWC->NCHW) => DWSC (Group Convolution) => Transpose(NCHW->NHWC) */
|
||||
TranspDWSCBiasTransp, /* Transpose(NHWC->NCHW) => DWSC => Broadcasted Add (Bias) => Transpose(NCHW->NHWC) */
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
modelType, // Test model
|
||||
ngraph::Shape, // Input shape
|
||||
ngraph::Shape, // Convolution filter shape
|
||||
ngraph::Strides, // Convolution stride
|
||||
ngraph::CoordinateDiff, // Convolution pads begin
|
||||
ngraph::CoordinateDiff, // Convolution pads end
|
||||
ngraph::Strides, // Convolution dilation
|
||||
ngraph::Shape, // Bias shape
|
||||
ngraph::op::PadType // Padding type
|
||||
> DWSCToScaleShiftsParams;
|
||||
|
||||
typedef std::tuple<
|
||||
bool, // With / without Fake Quantize layers
|
||||
DWSCToScaleShiftsParams // Test parameters
|
||||
> fqDWSCToScaleShiftsParams;
|
||||
|
||||
std::shared_ptr<ngraph::opset7::FakeQuantize> createFQ(std::shared_ptr<ngraph::Node>& in_node) {
|
||||
auto input_low = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto input_high = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {5});
|
||||
auto output_low = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto output_high = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
|
||||
return std::make_shared<ngraph::opset7::FakeQuantize>(in_node, input_low, input_high, output_low, output_high, 11);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> createBiasFQ(const std::shared_ptr<ngraph::Node>& in_node,
|
||||
std::shared_ptr<ngraph::opset7::Constant>& bias_const, const bool& fq) {
|
||||
std::shared_ptr<ngraph::Node> node;
|
||||
node = std::make_shared<ngraph::opset7::Add>(in_node, bias_const);
|
||||
|
||||
if (fq) {
|
||||
node = createFQ(node);
|
||||
}
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::opset7::Result> createFunction(const bool& fq,
|
||||
const modelType& model,
|
||||
const ngraph::Output<ngraph::Node>& input_node,
|
||||
const ngraph::Shape& filters_shape,
|
||||
const ngraph::Strides& conv_stride,
|
||||
const ngraph::CoordinateDiff& pads_begin,
|
||||
const ngraph::CoordinateDiff& pads_end,
|
||||
const ngraph::Strides& conv_dilation,
|
||||
const ngraph::Shape& bias_shape,
|
||||
const ngraph::op::PadType& pad_type,
|
||||
std::shared_ptr<ngraph::opset7::GroupConvolution>& dwsc,
|
||||
std::shared_ptr<ngraph::opset7::Constant>& bias_const,
|
||||
std::shared_ptr<ngraph::opset7::FakeQuantize>& fq_bias) {
|
||||
std::shared_ptr<ngraph::Node> fq_filters;
|
||||
|
||||
auto transpose_in_order = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64, ngraph::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
|
||||
auto transpose_in = std::make_shared<ngraph::opset7::Transpose>(input_node, transpose_in_order);
|
||||
|
||||
if (fq) {
|
||||
fq_filters = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64,
|
||||
ngraph::Shape{input_node.get_shape()[3], 1, filters_shape[0], filters_shape[1]});
|
||||
fq_filters = createFQ(fq_filters);
|
||||
fq_filters = std::make_shared<ngraph::opset7::Reshape>(fq_filters,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{5},
|
||||
ngraph::Shape{input_node.get_shape()[3], 1, 1, filters_shape[0], filters_shape[1]}), false);
|
||||
} else {
|
||||
fq_filters = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64,
|
||||
ngraph::Shape{input_node.get_shape()[3], 1, 1, filters_shape[0], filters_shape[1]});
|
||||
}
|
||||
|
||||
dwsc = std::make_shared<ngraph::opset7::GroupConvolution>(transpose_in, fq_filters, conv_stride, pads_begin, pads_end, conv_dilation, pad_type);
|
||||
auto transpose_out_order = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64, ngraph::Shape{4}, std::vector<int64_t>{0, 2, 3, 1});
|
||||
auto last_op = std::make_shared<ngraph::opset7::Transpose>(dwsc, transpose_out_order);
|
||||
|
||||
if (model == modelType::TranspDWSCBiasTransp || fq) {
|
||||
bias_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64, bias_shape);
|
||||
auto bias = createBiasFQ(dwsc, bias_const, fq);
|
||||
fq_bias = std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(bias);
|
||||
last_op = std::make_shared<ngraph::opset7::Transpose>(bias, transpose_out_order);
|
||||
}
|
||||
|
||||
return std::make_shared<ngraph::opset7::Result>(last_op);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> get_initial_function(const bool& fq,
|
||||
const modelType& model,
|
||||
const ngraph::Shape& input_shape,
|
||||
const ngraph::Shape& filters_shape,
|
||||
const ngraph::Strides& conv_stride,
|
||||
const ngraph::CoordinateDiff& pads_begin,
|
||||
const ngraph::CoordinateDiff& pads_end,
|
||||
const ngraph::Strides& conv_dilation,
|
||||
const ngraph::Shape& bias_shape,
|
||||
const ngraph::op::PadType& pad_type,
|
||||
std::shared_ptr<ngraph::opset7::GroupConvolution>& dwsc,
|
||||
std::shared_ptr<ngraph::opset7::Constant>& bias_const,
|
||||
std::shared_ptr<ngraph::opset7::FakeQuantize>& fq_bias) {
|
||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
||||
auto result = createFunction(fq, model, input_params, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
|
||||
bias_shape, pad_type, dwsc, bias_const, fq_bias);
|
||||
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
class ConvertDWSCToScaleShiftsTestInvalidFixture : public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<fqDWSCToScaleShiftsParams> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||
modelType model;
|
||||
};
|
||||
|
||||
void ConvertDWSCToScaleShiftsTestInvalidFixture::SetUp() {
|
||||
bool fq;
|
||||
DWSCToScaleShiftsParams params;
|
||||
ngraph::Shape input_shape;
|
||||
ngraph::Shape filters_shape, bias_shape;
|
||||
ngraph::Strides conv_stride, conv_dilation;
|
||||
ngraph::CoordinateDiff pads_begin, pads_end;
|
||||
ngraph::op::PadType pad_type;
|
||||
std::shared_ptr<ngraph::opset7::GroupConvolution> dwsc;
|
||||
std::shared_ptr<ngraph::opset7::Constant> bias_const;
|
||||
std::shared_ptr<ngraph::opset7::FakeQuantize> fq_bias;
|
||||
std::tie(fq, params) = this->GetParam();
|
||||
std::tie(model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
|
||||
bias_shape, pad_type) = params;
|
||||
|
||||
function = get_initial_function(fq, model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
|
||||
bias_shape, pad_type, dwsc, bias_const, fq_bias);
|
||||
reference_function = get_initial_function(fq, model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
|
||||
bias_shape, pad_type, dwsc, bias_const, fq_bias);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
class ConvertDWSCToScaleShiftsTestFixture: public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<fqDWSCToScaleShiftsParams> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
std::shared_ptr<ngraph::Function> get_reference(const bool& fq,
|
||||
const modelType& model,
|
||||
const ngraph::Shape& input_shape,
|
||||
const ngraph::Shape& filters_shape,
|
||||
const ngraph::Strides& conv_stride,
|
||||
const ngraph::CoordinateDiff& pads_begin,
|
||||
const ngraph::CoordinateDiff& pads_end,
|
||||
const ngraph::Strides& conv_dilation,
|
||||
const ngraph::Shape& bias_shape,
|
||||
const ngraph::op::PadType& pad_type,
|
||||
const std::shared_ptr<ngraph::opset7::GroupConvolution>& dwsc,
|
||||
const std::shared_ptr<ngraph::opset7::Constant>& bias_const,
|
||||
const std::shared_ptr<ngraph::opset7::FakeQuantize>& fq_bias);
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||
modelType model;
|
||||
};
|
||||
|
||||
void ConvertDWSCToScaleShiftsTestFixture::SetUp() {
|
||||
bool fq;
|
||||
DWSCToScaleShiftsParams params;
|
||||
ngraph::Shape input_shape;
|
||||
ngraph::Shape filters_shape, bias_shape;
|
||||
ngraph::Strides conv_stride, conv_dilation;
|
||||
ngraph::CoordinateDiff pads_begin, pads_end;
|
||||
ngraph::op::PadType pad_type;
|
||||
std::shared_ptr<ngraph::opset7::GroupConvolution> dwsc;
|
||||
std::shared_ptr<ngraph::opset7::Constant> bias_const;
|
||||
std::shared_ptr<ngraph::opset7::FakeQuantize> fq_bias;
|
||||
std::tie(fq, params) = this->GetParam();
|
||||
std::tie(model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
|
||||
bias_shape, pad_type) = params;
|
||||
|
||||
function = get_initial_function(fq, model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
|
||||
bias_shape, pad_type, dwsc, bias_const, fq_bias);
|
||||
reference_function = get_reference(fq, model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
|
||||
bias_shape, pad_type, dwsc, bias_const, fq_bias);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::opset7::StridedSlice> FlatCrop(ngraph::Output<ngraph::Node> input, size_t offset, size_t size) {
|
||||
return std::make_shared<ngraph::opset7::StridedSlice>(
|
||||
input, // data
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {(size_t)0, offset}), // begin sice index
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {(size_t)0, offset + size}), // end slice index
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {(size_t)1, (size_t)1}), // strides
|
||||
std::vector<int64_t>{1, 0}, // begin mask
|
||||
std::vector<int64_t>{1, 0}); // end mask
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> InsertFQLayer(const std::shared_ptr<ngraph::opset7::FakeQuantize> fq_layer,
|
||||
std::shared_ptr<ngraph::Node> last_node) {
|
||||
if (fq_layer != nullptr) {
|
||||
return fq_layer->clone_with_new_inputs({last_node,
|
||||
fq_layer->input_value(1), fq_layer->input_value(2),
|
||||
fq_layer->input_value(3), fq_layer->input_value(4)});
|
||||
}
|
||||
return last_node;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> DecomposeDWSC(std::shared_ptr<ngraph::opset7::GroupConvolution> dwsc,
|
||||
std::shared_ptr<ngraph::opset7::Constant> bias_const, std::shared_ptr<ngraph::opset7::FakeQuantize> fq_bias,
|
||||
std::shared_ptr<ngraph::opset7::Reshape> flat_input_plane, std::shared_ptr<ngraph::Node> flat_filters_plane) {
|
||||
std::shared_ptr<ngraph::opset7::Constant> const_zero_padding;
|
||||
std::shared_ptr<ngraph::Node> reshaped_bias;
|
||||
ngraph::OutputVector output_chunks;
|
||||
auto input_channel_count = dwsc->get_input_shape(0)[1];
|
||||
auto input_width = dwsc->get_input_shape(0)[3];
|
||||
auto output_width = dwsc->get_output_shape(0)[3];
|
||||
auto filter_width = dwsc->get_input_shape(1)[4];
|
||||
auto pads_begin = dwsc->get_pads_begin()[1];
|
||||
auto stride_width = dwsc->get_strides()[1];
|
||||
auto dilation_width = dwsc->get_dilations()[1];
|
||||
|
||||
// Constant with zero padding
|
||||
if (pads_begin) {
|
||||
const_zero_padding = std::make_shared<ngraph::opset7::Constant>(dwsc->get_element_type(), ngraph::Shape{1, input_channel_count}, 0);
|
||||
}
|
||||
|
||||
// Reshape bias const
|
||||
if (bias_const) {
|
||||
auto bias_size = shape_size(bias_const->get_shape());
|
||||
reshaped_bias = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(bias_const,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, bias_size}), false);
|
||||
}
|
||||
|
||||
// Move filter over input performing multiplication and addition (scaleshift), take padding, stride, dilation and bias into account
|
||||
for (int32_t input_position = -pads_begin, o = 0; o < output_width; input_position += stride_width, o++) {
|
||||
std::shared_ptr<ngraph::Node> previous_layer_output, last_layer_output;
|
||||
int32_t filter_end = input_position + filter_width * dilation_width;
|
||||
bool first = true;
|
||||
|
||||
filter_end = filter_end < input_width ? filter_end : input_width;
|
||||
|
||||
for (int32_t filter_pos = input_position, filter_idx = 0; filter_pos < filter_end; filter_pos += dilation_width, filter_idx++) {
|
||||
if (filter_pos >= 0) {
|
||||
auto conv_input_slice = FlatCrop(flat_input_plane, filter_pos * input_channel_count, input_channel_count);
|
||||
auto conv_filter_slice = FlatCrop(flat_filters_plane, filter_idx * input_channel_count, input_channel_count);
|
||||
|
||||
if (first) {
|
||||
first = false;
|
||||
previous_layer_output = std::make_shared<ngraph::opset7::Multiply>(conv_input_slice, conv_filter_slice);
|
||||
if (bias_const) {
|
||||
previous_layer_output = std::make_shared<ngraph::opset7::Add>(previous_layer_output, reshaped_bias);
|
||||
previous_layer_output = InsertFQLayer(fq_bias, previous_layer_output);
|
||||
}
|
||||
last_layer_output = previous_layer_output;
|
||||
} else {
|
||||
last_layer_output = std::make_shared<ngraph::opset7::Multiply>(conv_input_slice, conv_filter_slice);
|
||||
last_layer_output = std::make_shared<ngraph::opset7::Add>(last_layer_output, previous_layer_output);
|
||||
previous_layer_output = last_layer_output;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!last_layer_output) {
|
||||
IE_ASSERT(const_zero_padding);
|
||||
last_layer_output = const_zero_padding;
|
||||
}
|
||||
|
||||
output_chunks.push_back(last_layer_output);
|
||||
}
|
||||
|
||||
// Concat and transpose is only needed when output width > 1
|
||||
if (output_chunks.size() > 1) {
|
||||
return std::make_shared<ngraph::opset7::Concat>(output_chunks, 0);
|
||||
}
|
||||
|
||||
return output_chunks[0].get_node_shared_ptr();
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> ConvertDWSCToScaleShiftsTestFixture::get_reference(const bool& fq,
|
||||
const modelType& model,
|
||||
const ngraph::Shape& input_shape,
|
||||
const ngraph::Shape& filters_shape,
|
||||
const ngraph::Strides& conv_stride,
|
||||
const ngraph::CoordinateDiff& pads_begin,
|
||||
const ngraph::CoordinateDiff& pads_end,
|
||||
const ngraph::Strides& conv_dilation,
|
||||
const ngraph::Shape& bias_shape,
|
||||
const ngraph::op::PadType& pad_type,
|
||||
const std::shared_ptr<ngraph::opset7::GroupConvolution>& dwsc,
|
||||
const std::shared_ptr<ngraph::opset7::Constant>& bias_const,
|
||||
const std::shared_ptr<ngraph::opset7::FakeQuantize>& fq_bias) {
|
||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
||||
auto output_channel_count = dwsc->get_output_shape(0)[1];
|
||||
auto output_width = dwsc->get_output_shape(0)[3];
|
||||
|
||||
// Prepare flat input data
|
||||
auto flat_input_plane = std::make_shared<ngraph::opset7::Reshape>(input_params,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
|
||||
ngraph::Shape{1, ngraph::shape_size(input_shape)}), false);
|
||||
|
||||
// Prepare flat filter data
|
||||
auto filters_const = std::dynamic_pointer_cast<ngraph::Node>(dwsc->get_input_node_shared_ptr(1));
|
||||
auto filters_size = ngraph::shape_size(filters_const->get_shape());
|
||||
|
||||
auto transposed_filters_const = ngraph::op::util::make_try_fold<ngraph::opset7::Transpose>(filters_const,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{5}, ngraph::Shape{4, 1, 2, 3, 0}));
|
||||
|
||||
auto flat_filters_plane = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(transposed_filters_const,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, filters_size}), false);
|
||||
|
||||
// Convert DWSC to a set of diagonal layers
|
||||
auto output_plane = DecomposeDWSC(dwsc, bias_const, fq_bias, flat_input_plane, flat_filters_plane);
|
||||
|
||||
// Restore the original output shape
|
||||
auto result = std::make_shared<ngraph::opset7::Reshape>(output_plane,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
|
||||
ngraph::Shape{1, output_channel_count, 1, output_width}), false);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::ResultVector{std::make_shared<ngraph::opset7::Result>(result)}, ngraph::ParameterVector{input_params});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
void execute_test(modelType model, std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
|
||||
manager.register_pass<GNAPluginNS::ConvertDWSCToScaleShifts>();
|
||||
manager.run_passes(function);
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||
ASSERT_TRUE(result.valid);
|
||||
}
|
||||
|
||||
TEST_P(ConvertDWSCToScaleShiftsTestFixture, CompareFunctions) {
|
||||
execute_test(model, function, reference_function);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertDWSCToScaleShiftsTestSuite, ConvertDWSCToScaleShiftsTestFixture,
|
||||
::testing::Combine(
|
||||
// With / without Fake Quantize layers
|
||||
::testing::Values(true, false),
|
||||
::testing::Values(
|
||||
std::make_tuple(modelType::TranspDWSCTransp, ngraph::Shape{1, 1, 5, 32}, ngraph::Shape{1, 3}, ngraph::Strides{1, 1},
|
||||
ngraph::CoordinateDiff{0, 1}, ngraph::CoordinateDiff{0, 1}, ngraph::Strides{1, 1},
|
||||
ngraph::Shape{1, 32, 1, 1}, ngraph::op::PadType::VALID),
|
||||
std::make_tuple(modelType::TranspDWSCBiasTransp, ngraph::Shape{1, 1, 5, 32}, ngraph::Shape{1, 3}, ngraph::Strides{1, 1},
|
||||
ngraph::CoordinateDiff{0, 2}, ngraph::CoordinateDiff{0, 2}, ngraph::Strides{1, 1},
|
||||
ngraph::Shape{1, 32, 1, 1}, ngraph::op::PadType::VALID))));
|
||||
|
||||
TEST_P(ConvertDWSCToScaleShiftsTestInvalidFixture, CompareFunctions) {
|
||||
execute_test(model, function, reference_function);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertDWSCToScaleShiftsInvalidTestSuite, ConvertDWSCToScaleShiftsTestInvalidFixture,
|
||||
::testing::Combine(
|
||||
// With / without Fake Quantize layers
|
||||
::testing::Values(true, false),
|
||||
::testing::Values(
|
||||
std::make_tuple(modelType::TranspDWSCTransp, ngraph::Shape{2, 16, 8, 1}, ngraph::Shape{1, 2}, ngraph::Strides{1, 1},
|
||||
ngraph::CoordinateDiff{0, 2}, ngraph::CoordinateDiff{0, 3}, ngraph::Strides{1, 1},
|
||||
ngraph::Shape{1, 4, 1, 1}, ngraph::op::PadType::SAME_UPPER),
|
||||
std::make_tuple(modelType::TranspDWSCBiasTransp, ngraph::Shape{2, 16, 8, 1}, ngraph::Shape{1, 2}, ngraph::Strides{1, 1},
|
||||
ngraph::CoordinateDiff{0, 2}, ngraph::CoordinateDiff{0, 3}, ngraph::Strides{1, 1},
|
||||
ngraph::Shape{1, 4, 1, 1}, ngraph::op::PadType::EXPLICIT))));
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace testing
|
@ -6,7 +6,7 @@
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "transformations/convert_padded2valid_conv.hpp"
|
||||
#include "transformations/convert_padded_to_valid_convolution.hpp"
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
@ -39,12 +39,12 @@ typedef std::tuple<
|
||||
ngraph::Strides, // Max Pool stride
|
||||
ngraph::Shape, // Max Pool shape
|
||||
ngraph::op::PadType // Padding type
|
||||
> padded2ValidConvParams;
|
||||
> paddedToValidConvParams;
|
||||
|
||||
typedef std::tuple<
|
||||
bool, // With / without Fake Quantize layers
|
||||
padded2ValidConvParams // Test parameters
|
||||
> fqPadded2ValidConvParams;
|
||||
paddedToValidConvParams // Test parameters
|
||||
> fqPaddedToValidConvParams;
|
||||
|
||||
struct ConvData {
|
||||
size_t input_height;
|
||||
@ -193,17 +193,17 @@ std::shared_ptr<ngraph::Function> get_initial_function(const bool& fq,
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
class ConvertPadded2ValidConvTestInvalidFixture : public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<fqPadded2ValidConvParams> {
|
||||
class ConvertPaddedToValidConvTestInvalidFixture : public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<fqPaddedToValidConvParams> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||
};
|
||||
|
||||
void ConvertPadded2ValidConvTestInvalidFixture::SetUp() {
|
||||
void ConvertPaddedToValidConvTestInvalidFixture::SetUp() {
|
||||
bool fq;
|
||||
padded2ValidConvParams params;
|
||||
paddedToValidConvParams params;
|
||||
modelType model;
|
||||
ngraph::PartialShape input_shape;
|
||||
ngraph::Shape filters_shape, bias_shape, maxpool_shape;
|
||||
@ -223,8 +223,8 @@ void ConvertPadded2ValidConvTestInvalidFixture::SetUp() {
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
class ConvertPadded2ValidConvTestFixture: public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<fqPadded2ValidConvParams> {
|
||||
class ConvertPaddedToValidConvTestFixture: public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<fqPaddedToValidConvParams> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
std::shared_ptr<ngraph::Function> get_reference(const bool& fq,
|
||||
@ -244,9 +244,9 @@ public:
|
||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||
};
|
||||
|
||||
void ConvertPadded2ValidConvTestFixture::SetUp() {
|
||||
void ConvertPaddedToValidConvTestFixture::SetUp() {
|
||||
bool fq;
|
||||
padded2ValidConvParams params;
|
||||
paddedToValidConvParams params;
|
||||
modelType model;
|
||||
ngraph::PartialShape input_shape;
|
||||
ngraph::Shape filters_shape, bias_shape, maxpool_shape;
|
||||
@ -354,7 +354,7 @@ std::shared_ptr<ngraph::Node> CreatePaddedNet(const ngraph::Output<ngraph::Node>
|
||||
return padded_input_plane;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> ConvertPadded2ValidConvTestFixture::get_reference(const bool& fq,
|
||||
std::shared_ptr<ngraph::Function> ConvertPaddedToValidConvTestFixture::get_reference(const bool& fq,
|
||||
const modelType& model,
|
||||
const ngraph::PartialShape& input_shape,
|
||||
const ngraph::Shape& filters_shape,
|
||||
@ -406,18 +406,18 @@ std::shared_ptr<ngraph::Function> ConvertPadded2ValidConvTestFixture::get_refere
|
||||
void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<GNAPluginNS::ConvertPadded2ValidConv>();
|
||||
manager.register_pass<GNAPluginNS::ConvertPaddedToValidConv>();
|
||||
manager.run_passes(function);
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||
ASSERT_TRUE(result.valid);
|
||||
}
|
||||
|
||||
TEST_P(ConvertPadded2ValidConvTestFixture, CompareFunctions) {
|
||||
TEST_P(ConvertPaddedToValidConvTestFixture, CompareFunctions) {
|
||||
execute_test(function, reference_function);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertPadded2ValidConvTestSuite, ConvertPadded2ValidConvTestFixture,
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertPaddedToValidConvTestSuite, ConvertPaddedToValidConvTestFixture,
|
||||
::testing::Combine(
|
||||
// With / without Fake Quantize layers
|
||||
::testing::Values(true, false),
|
||||
@ -444,11 +444,11 @@ INSTANTIATE_TEST_SUITE_P(ConvertPadded2ValidConvTestSuite, ConvertPadded2ValidCo
|
||||
ngraph::CoordinateDiff{0, 2}, ngraph::CoordinateDiff{0, 3}, ngraph::Strides{1, 1},
|
||||
ngraph::Shape{1, 1, 1, 4}, ngraph::Strides{1, 1}, ngraph::Shape{1, 2}, ngraph::op::PadType::EXPLICIT))));
|
||||
|
||||
TEST_P(ConvertPadded2ValidConvTestInvalidFixture, CompareFunctions) {
|
||||
TEST_P(ConvertPaddedToValidConvTestInvalidFixture, CompareFunctions) {
|
||||
execute_test(function, reference_function);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertPadded2ValidConvInvalidTestSuite, ConvertPadded2ValidConvTestInvalidFixture,
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertPaddedToValidConvInvalidTestSuite, ConvertPaddedToValidConvTestInvalidFixture,
|
||||
::testing::Combine(
|
||||
// With / without Fake Quantize layers
|
||||
::testing::Values(true, false),
|
@ -6,7 +6,7 @@
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "transformations/decompose_2d_conv.hpp"
|
||||
#include "transformations/decompose_2d_convolution.hpp"
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
@ -426,7 +426,7 @@ void TransformInput(const GraphData& graph_data, const ConvParams& conv_params,
|
||||
*/
|
||||
|
||||
// First we need to prepare flat (height = 1) slices of input data proper for flattened (height = 1) filters created later on;
|
||||
// the input datat is overlapping (duplicated)
|
||||
// the input data is overlapping (duplicated)
|
||||
ngraph::OutputVector dilated_input_planes;
|
||||
for (size_t filter_height = 0; filter_height < conv_params.filter_height; filter_height++) {
|
||||
size_t offset;
|
||||
@ -704,10 +704,13 @@ void execute_test(modelType model, std::shared_ptr<ngraph::Function> function, s
|
||||
case modelType::TranspConvBcastAddActTransp:
|
||||
case modelType::TranspConvBcastAddMaxPoolActTransp:
|
||||
manager.register_pass<GNAPluginNS::Decompose2DConv>();
|
||||
break;
|
||||
case modelType::TranspConvTranspBcastAdd:
|
||||
manager.register_pass<GNAPluginNS::Decompose2DConvTransposedWithBias>();
|
||||
break;
|
||||
case modelType::TranspConvTranspBcastAddAct:
|
||||
manager.register_pass<GNAPluginNS::Decompose2DConvTransposedWithBiasAF>();
|
||||
break;
|
||||
}
|
||||
|
||||
manager.run_passes(function);
|
Loading…
Reference in New Issue
Block a user