[GNA] Depth-wise separable convolution support (#7281)

* [GNA] Add support for DWSC, other fixes and code refactoring.

* [GNA] Change supported layout to NHWC

* [GNA] Detect bias const only on the second input position, move verification of DWSC to the matcher
Szymon Irzabek 2021-09-16 10:54:51 +02:00 committed by GitHub
parent 10f0075e90
commit 57b51701fa
14 changed files with 947 additions and 118 deletions
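For orientation, here is a minimal sketch of the subgraph the new ConvertDWSCToScaleShifts pass targets (assuming ngraph opset7 and the NHWC input layout required by the pass; shapes are taken from the smoke tests below, otherwise hypothetical):

#include <ngraph/opsets/opset7.hpp>

// Transpose(NHWC->NCHW) => depth-wise GroupConvolution => Transpose(NCHW->NHWC)
std::shared_ptr<ngraph::Function> make_dwsc_example() {
    using namespace ngraph;
    auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 1, 5, 32}); // NHWC
    auto to_nchw = opset7::Constant::create(element::i64, Shape{4}, {0, 3, 1, 2});
    auto transpose_in = std::make_shared<opset7::Transpose>(input, to_nchw);
    // Depth-wise filters: group count == input channel count, one output channel per group
    auto filters = opset7::Constant::create(element::f32, Shape{32, 1, 1, 1, 3},
                                            std::vector<float>(32 * 3, 0.1f));
    auto dwsc = std::make_shared<opset7::GroupConvolution>(transpose_in, filters,
        Strides{1, 1}, CoordinateDiff{0, 1}, CoordinateDiff{0, 1}, Strides{1, 1});
    auto to_nhwc = opset7::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1});
    auto transpose_out = std::make_shared<opset7::Transpose>(dwsc, to_nhwc);
    return std::make_shared<Function>(OutputVector{transpose_out}, ParameterVector{input});
}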

View File

@ -64,9 +64,10 @@
#include "transformations/convert_matmul_to_pointwise_convolution.hpp"
#include "transformations/split_convolution_with_large_buffer_size.hpp"
#include "transformations/handle_transposes_around_matmul.hpp"
#include "transformations/decompose_2d_conv.hpp"
#include "transformations/convert_padded2valid_conv.hpp"
#include "transformations/decompose_2d_convolution.hpp"
#include "transformations/convert_padded_to_valid_convolution.hpp"
#include "transformations/insert_reshape_around_matmul.hpp"
#include "transformations/convert_dwsc_to_scaleshifts.hpp"
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
#include "transformations/remove_single_input_concat.hpp"
@ -716,7 +717,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<ngraph::pass::ConvertPriorBox>();
manager.register_pass<ngraph::pass::CommonOptimizations>();
manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
manager.register_pass<ConvertPadded2ValidConv>();
manager.register_pass<ConvertDWSCToScaleShifts>();
manager.register_pass<ConvertPaddedToValidConv>();
if (config.gnaCompileTarget == InferenceEngine::GNAConfigParams::GNA_TARGET_2_0) {
manager.register_pass<Decompose2DConvTransposedWithBiasAF>();
manager.register_pass<Decompose2DConvTransposedWithBias>();
@ -748,7 +750,6 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<RemoveExtraReshapes>();
// UnrollTI should be the last transformation in the transformation pipeline
manager.register_pass<ngraph::pass::UnrollTensorIterator>();
const auto& pass_config = manager.get_pass_config();
pass_config->disable<ngraph::pass::FakeQuantizeMulFusion>();
pass_config->disable<ngraph::pass::FakeQuantizeReshapeFusion>();

View File

@ -0,0 +1,207 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/cc/ngraph/itt.hpp>
#include "transformations/convert_dwsc_to_scaleshifts.hpp"
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/rt_info.hpp>
#include <ie_common.h>
#include "utils/transformation_helper.hpp"
using namespace GNAPluginNS;
NGRAPH_RTTI_DEFINITION(ConvertDWSCToScaleShifts, "ConvertDWSCToScaleShifts", 0);
static std::shared_ptr<ngraph::Node> DecomposeDWSC(std::shared_ptr<ngraph::opset7::GroupConvolution> dwsc,
std::shared_ptr<ngraph::opset7::Constant> bias_const, std::shared_ptr<ngraph::opset7::FakeQuantize> fq_bias,
std::shared_ptr<ngraph::opset7::Reshape> flat_input_plane, std::shared_ptr<ngraph::Node> flat_filters_plane) {
std::shared_ptr<ngraph::opset7::Constant> const_zero_padding;
std::shared_ptr<ngraph::Node> reshaped_bias;
ngraph::OutputVector output_chunks;
auto input_channel_count = dwsc->get_input_shape(0)[1];
auto input_width = dwsc->get_input_shape(0)[3];
auto output_width = dwsc->get_output_shape(0)[3];
auto filter_width = dwsc->get_input_shape(1)[4];
auto pads_begin = dwsc->get_pads_begin()[1];
auto stride_width = dwsc->get_strides()[1];
auto dilation_width = dwsc->get_dilations()[1];
// Constant with zero padding
if (pads_begin) {
const_zero_padding = std::make_shared<ngraph::opset7::Constant>(dwsc->get_element_type(), ngraph::Shape{1, input_channel_count}, 0);
copy_runtime_info(dwsc, const_zero_padding);
}
// Reshape bias const
if (bias_const) {
auto bias_size = shape_size(bias_const->get_shape());
reshaped_bias = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(bias_const,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, bias_size}), false);
}
// Move filter over input performing multiplication and addition (scaleshift), take padding, stride, dilation and bias into account
for (int32_t input_position = -pads_begin, o = 0; o < output_width; input_position += stride_width, o++) {
std::shared_ptr<ngraph::Node> previous_layer_output, last_layer_output;
int32_t filter_end = input_position + filter_width * dilation_width;
bool first = true;
filter_end = filter_end < input_width ? filter_end : input_width;
for (int32_t filter_pos = input_position, filter_idx = 0; filter_pos < filter_end; filter_pos += dilation_width, filter_idx++) {
if (filter_pos >= 0) {
auto conv_input_slice = FlatCrop(flat_input_plane, filter_pos * input_channel_count, input_channel_count);
auto conv_filter_slice = FlatCrop(flat_filters_plane, filter_idx * input_channel_count, input_channel_count);
if (first) {
first = false;
previous_layer_output = std::make_shared<ngraph::opset7::Multiply>(conv_input_slice, conv_filter_slice);
copy_runtime_info(dwsc, previous_layer_output);
if (bias_const) {
previous_layer_output = std::make_shared<ngraph::opset7::Add>(previous_layer_output, reshaped_bias);
copy_runtime_info(dwsc, previous_layer_output);
previous_layer_output = InsertFQLayer(fq_bias, previous_layer_output);
}
last_layer_output = previous_layer_output;
} else {
last_layer_output = std::make_shared<ngraph::opset7::Multiply>(conv_input_slice, conv_filter_slice);
copy_runtime_info(dwsc, last_layer_output);
last_layer_output = std::make_shared<ngraph::opset7::Add>(last_layer_output, previous_layer_output);
copy_runtime_info(dwsc, last_layer_output);
previous_layer_output = last_layer_output;
}
}
}
if (!last_layer_output) {
IE_ASSERT(const_zero_padding);
last_layer_output = const_zero_padding;
}
output_chunks.push_back(last_layer_output);
}
// Concat is only needed when output width > 1
if (output_chunks.size() > 1) {
auto concat_output_plane = std::make_shared<ngraph::opset7::Concat>(output_chunks, 0);
copy_runtime_info(dwsc, concat_output_plane);
return concat_output_plane;
}
return output_chunks[0].get_node_shared_ptr();
}
static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
std::shared_ptr<ngraph::Node> dwsc_node,
std::shared_ptr<ngraph::Node> bias_const_node,
std::shared_ptr<ngraph::Node> fq_bias_node,
std::shared_ptr<ngraph::Node> trailing_transpose) {
auto dwsc = std::dynamic_pointer_cast<ngraph::opset7::GroupConvolution>(dwsc_node);
auto bias_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias_const_node);
auto fq_bias = std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(fq_bias_node);
// We are looking for Transpose(NHWC->NCHW) => GroupConv => Transpose(NCHW->NHWC)
// or similar cases, so the network must be in NHWC order, as in TF
if (!TransposeOrderMatches(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(leading_transpose), {0, 3, 1, 2}))
return false;
if (!TransposeOrderMatches(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(trailing_transpose), {0, 2, 3, 1}))
return false;
auto output_channel_count = dwsc->get_output_shape(0)[1];
auto output_width = dwsc->get_output_shape(0)[3];
// Prepare flat input data
auto flat_input_plane = std::make_shared<ngraph::opset7::Reshape>(leading_transpose->input_value(0),
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
ngraph::Shape{1, shape_size(dwsc->input_value(0).get_shape())}), false);
// Prepare flat filter data
auto filters_const = std::dynamic_pointer_cast<ngraph::Node>(dwsc->get_input_node_shared_ptr(1));
auto filters_size = shape_size(filters_const->get_shape());
auto transposed_filters_const = ngraph::op::util::make_try_fold<ngraph::opset7::Transpose>(filters_const,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{5}, ngraph::Shape{4, 1, 2, 3, 0}));
auto flat_filters_plane = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(transposed_filters_const,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, filters_size}), false);
copy_runtime_info(dwsc, {flat_input_plane, transposed_filters_const, flat_filters_plane});
// Convert DWSC to a set of diagonal layers
auto output_plane = DecomposeDWSC(dwsc, bias_const, fq_bias, flat_input_plane, flat_filters_plane);
// Restore the original output shape
auto result = std::make_shared<ngraph::opset7::Reshape>(output_plane,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
ngraph::Shape{1, output_channel_count, 1, output_width}), false);
copy_runtime_info(dwsc, result);
// We need to set the original Group Convolution layer name here, so the new layer's output can be used as a network result
std::string result_name = trailing_transpose->get_friendly_name();
replace_node(trailing_transpose, result);
result->set_friendly_name(result_name);
return true;
}
static bool VerifyDWSC(const ngraph::Output<ngraph::Node>& output) {
auto dwsc = output.get_node();
// Verify it's a 1D convolution
// Verify that filter group count == input channel count
// Verify that per group filter output channel count == 1
if (!consumers_and_rank(1, 4)(output) ||
dwsc->get_input_shape(1)[3] != 1 || dwsc->get_input_shape(0)[2] != 1 || dwsc->get_output_shape(0)[2] != 1 ||
dwsc->get_input_shape(1)[0] != dwsc->get_input_shape(0)[1] ||
dwsc->get_input_shape(1)[1] != 1)
return false;
return true;
}
ConvertDWSCToScaleShifts::ConvertDWSCToScaleShifts() {
MATCHER_SCOPE(ConvertDWSCToScaleShifts);
auto const_input = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto leading_transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({ngraph::pattern::any_input(), const_input},
consumers_and_rank(1, 4));
auto filters_const_fq = ngraph::pattern::wrap_type<ngraph::opset7::Constant>(ngraph::pattern::rank_equals(4));
auto fq_filters_const = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({filters_const_fq, const_input, const_input, const_input, const_input},
consumers_and_rank(1, 4));
auto reshape_filters_const = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({fq_filters_const, const_input}, ngraph::pattern::rank_equals(5));
auto filters_const = ngraph::pattern::wrap_type<ngraph::opset7::Constant>(ngraph::pattern::rank_equals(5));
auto dwsc_filters = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{filters_const, reshape_filters_const });
auto dwsc = ngraph::pattern::wrap_type<ngraph::opset7::GroupConvolution>({leading_transpose, dwsc_filters}, VerifyDWSC);
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Add>({dwsc, const_input});
auto fq_bias = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({bias, const_input, const_input, const_input, const_input},
consumers_and_rank(1, 4));
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{dwsc, bias, fq_bias});
auto trailing_transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({transpose_input, const_input}, consumers_and_rank(1, 4));
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto bias_it = pattern_map.find(bias);
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
std::shared_ptr<ngraph::Node> bias_const = nullptr;
if (bias_node && (bias_const = VerifyBiasGetConst(pattern_map.at(dwsc).get_node_shared_ptr(), bias_node)) == nullptr)
return false;
auto fq_bias_it = pattern_map.find(fq_bias);
auto fq_bias_node = (fq_bias_it == std::end(pattern_map) ? nullptr : fq_bias_it->second.get_node_shared_ptr());
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), pattern_map.at(dwsc).get_node_shared_ptr(),
bias_const, fq_bias_node,
pattern_map.at(trailing_transpose).get_node_shared_ptr());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(trailing_transpose, matcher_name);
this->register_matcher(m, callback);
}
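To make the sliding-window arithmetic of DecomposeDWSC easy to verify by hand, here is a standalone sketch (plain C++, hypothetical parameters taken from the smoke tests) that prints which flat input slice is multiplied with which filter slice for each output column, mirroring the two loops above including padding, stride and dilation handling:

#include <cstdint>
#include <cstdio>

// Hypothetical parameters: input width 5, kernel width 3, pads_begin 1,
// stride 1, dilation 1 => output width 5.
int main() {
    const int32_t input_width = 5, output_width = 5, filter_width = 3;
    const int32_t pads_begin = 1, stride = 1, dilation = 1;
    for (int32_t input_position = -pads_begin, o = 0; o < output_width; input_position += stride, o++) {
        int32_t filter_end = input_position + filter_width * dilation;
        filter_end = filter_end < input_width ? filter_end : input_width;
        std::printf("output[%d] =", o);
        for (int32_t filter_pos = input_position, filter_idx = 0; filter_pos < filter_end;
             filter_pos += dilation, filter_idx++) {
            if (filter_pos >= 0) // negative positions fall into the zero padding and are skipped
                std::printf(" + in[%d]*w[%d]", filter_pos, filter_idx);
        }
        std::printf("\n");
    }
    return 0;
}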

View File

@ -0,0 +1,21 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
namespace GNAPluginNS {
/**
* @brief Convert a depthwise separable convolution (represented by a GroupConvolution) to a set of ScaleShift layers (MatMul + Add)
* Bias and fake quantize layers are also supported.
*/
class ConvertDWSCToScaleShifts : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertDWSCToScaleShifts();
};
} // namespace GNAPluginNS
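The pass is run through the standard ngraph pass manager, as GNAPlugin::LoadNetwork does above; a minimal usage sketch:

#include <ngraph/pass/manager.hpp>
#include "transformations/convert_dwsc_to_scaleshifts.hpp"

// Run the DWSC decomposition on a function (mirrors the plugin pipeline registration).
void run_dwsc_pass(std::shared_ptr<ngraph::Function> function) {
    ngraph::pass::Manager manager;
    manager.register_pass<GNAPluginNS::ConvertDWSCToScaleShifts>();
    manager.run_passes(function);
}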

View File

@ -4,7 +4,7 @@
#include <openvino/cc/ngraph/itt.hpp>
#include "transformations/convert_padded2valid_conv.hpp"
#include "transformations/convert_padded_to_valid_convolution.hpp"
#include <memory>
@ -19,7 +19,7 @@
using namespace GNAPluginNS;
NGRAPH_RTTI_DEFINITION(ConvertPadded2ValidConv, "ConvertPadded2ValidConv", 0);
NGRAPH_RTTI_DEFINITION(ConvertPaddedToValidConv, "ConvertPaddedToValidConv", 0);
static bool VerifyAndGetConvData(std::shared_ptr<ngraph::opset7::Convolution> conv, ConvData& conv_data) {
const auto& input = conv->input_value(0);
@ -34,17 +34,6 @@ static bool VerifyAndGetConvData(std::shared_ptr<ngraph::opset7::Convolution> co
return conv_data.pads_begin_height || conv_data.pads_end_height || conv_data.pads_begin_width || conv_data.pads_end_width;
}
static bool VerifyBias(std::shared_ptr<ngraph::opset7::Add> bias, const size_t& filter_count) {
auto add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(0).get_node_shared_ptr());
// We need to check both inputs of Add when looking for constant
if (!add_const)
add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(1).get_node_shared_ptr());
// The add may be a normal add not convolution bias, then we just go further
return (add_const && shape_size(add_const->get_shape()) == filter_count);
}
static void InsertPadding(ngraph::OutputVector& input_rows_to_concat, size_t size, const std::shared_ptr<ngraph::opset7::Convolution>& conv,
const std::shared_ptr<ngraph::opset7::Constant> padding_const, size_t biggest_padding) {
@ -181,9 +170,6 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
if (!TransposeOrderMatches(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(trailing_transpose), {0, 2, 3, 1}))
return false;
if (bias && !VerifyBias(std::dynamic_pointer_cast<ngraph::opset7::Add>(bias), conv_data.filter_count))
return false;
GeneratePadding(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(leading_transpose),
std::dynamic_pointer_cast<ngraph::opset7::Convolution>(conv), conv_data);
@ -196,8 +182,8 @@ static std::function<bool(ngraph::Output<ngraph::Node>)> consumers_and_rank(cons
};
}
ConvertPadded2ValidConv::ConvertPadded2ValidConv() {
MATCHER_SCOPE(ConvertPadded2ValidConv);
ConvertPaddedToValidConv::ConvertPaddedToValidConv() {
MATCHER_SCOPE(ConvertPaddedToValidConv);
auto const_input = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto leading_transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({ngraph::pattern::any_input(), const_input},
@ -237,6 +223,9 @@ ConvertPadded2ValidConv::ConvertPadded2ValidConv() {
auto bias_it = pattern_map.find(bias);
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
if (bias_node && !VerifyBiasGetConst(pattern_map.at(conv).get_node_shared_ptr(), bias_node))
return false;
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), pattern_map.at(conv).get_node_shared_ptr(),
pattern_map.at(trailing_transpose).get_node_shared_ptr(), bias_node);
};

View File

@ -28,10 +28,10 @@ namespace GNAPluginNS {
* Transpose (NCHW -> NHWC) Transpose (NCHW -> NHWC)
*
*/
class ConvertPadded2ValidConv : public ngraph::pass::MatcherPass {
class ConvertPaddedToValidConv : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertPadded2ValidConv();
ConvertPaddedToValidConv();
};
} // namespace GNAPluginNS

View File

@ -4,9 +4,7 @@
#include <openvino/cc/ngraph/itt.hpp>
#include "transformations/decompose_2d_conv.hpp"
#include <memory>
#include "transformations/decompose_2d_convolution.hpp"
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
@ -68,22 +66,6 @@ static bool VerifyAndGetConvData(std::shared_ptr<ngraph::opset7::Convolution> co
return true;
}
static std::shared_ptr<ngraph::Node> VerifyBiasAndReshapeConst(std::shared_ptr<ngraph::opset7::Add> conv_bias, const ConvData& conv_data) {
auto add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(conv_bias->input_value(1).get_node_shared_ptr());
if (add_const) {
auto bias_size = shape_size(add_const->get_shape());
// The add may be a normal add not conv bias, then we just go further
if (bias_size == conv_data.filter_count) {
return ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(add_const,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{1, bias_size, 1, 1}), false);
}
}
// Bias size does not match (or dynamic bias), can't decompose such convolution
return nullptr;
}
static bool VerifyMaxPool(GraphData& graph_data, std::shared_ptr<ngraph::opset7::MaxPool> max_pool) {
auto pool_filter = max_pool->get_kernel();
auto pool_strides = max_pool->get_strides();
@ -236,7 +218,7 @@ static void TransformInput(const GraphData& graph_data, const ConvData& conv_dat
*/
// First we need to prepare flat (height = 1) slices of input data proper for flattened (height = 1) filters created later on;
// the input datat is overlapping (duplicated)
// the input data is overlapping (duplicated)
ngraph::OutputVector dilated_input_planes;
for (size_t filter_height = 0; filter_height < conv_data.filter_height; filter_height++) {
size_t offset;
@ -280,16 +262,6 @@ static void TransformInput(const GraphData& graph_data, const ConvData& conv_dat
split_input_plane = flattened_dilated_transposed_input;
}
static void InsertFQLayer(const std::shared_ptr<ngraph::opset7::FakeQuantize> fqLayer,
std::shared_ptr<ngraph::Node> lastNode) {
if (fqLayer != nullptr) {
lastNode = fqLayer->clone_with_new_inputs({lastNode,
fqLayer->input_value(1), fqLayer->input_value(2),
fqLayer->input_value(3), fqLayer->input_value(4)});
ngraph::copy_runtime_info(fqLayer, lastNode);
}
}
// Valid 1D (decomposed 2D) convolution wrapped with transposes NHWC => NCHW => conv => NCHW => NHWC
static std::shared_ptr<ngraph::Node> Create1DConv(const GraphData& graph_data, const ConvData& conv_data, const ngraph::Output<ngraph::Node>& input,
std::shared_ptr<ngraph::Node> filters, const size_t conv_index, const size_t h_index) {
@ -298,7 +270,7 @@ static std::shared_ptr<ngraph::Node> Create1DConv(const GraphData& graph_data, c
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 3, 1, 2})->output(0));
// Fake quantize
InsertFQLayer(graph_data.fq_conv, filters);
filters = InsertFQLayer(graph_data.fq_conv, filters);
// 1D Convolution
auto conv = std::make_shared<ngraph::opset7::Convolution>(nchw_input, filters,
@ -306,13 +278,16 @@ static std::shared_ptr<ngraph::Node> Create1DConv(const GraphData& graph_data, c
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
std::string conv_name = graph_data.conv->get_friendly_name() + "_H_" + std::to_string(h_index) + "_CH_" + std::to_string(0);
conv->set_friendly_name(conv_name);
std::shared_ptr<ngraph::Node> last_conv_block_op = conv;
// Bias & fake quantize
std::shared_ptr<ngraph::Node> last_conv_block_op = conv;
if (graph_data.bias_const && conv_index == 0) {
last_conv_block_op = std::make_shared<ngraph::opset7::Add>(conv, graph_data.bias_const);
auto bias_size = shape_size(graph_data.bias_const->get_shape());
auto reshaped_bias_const = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(graph_data.bias_const,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{1, bias_size, 1, 1}), false);
last_conv_block_op = std::make_shared<ngraph::opset7::Add>(conv, reshaped_bias_const);
copy_runtime_info(graph_data.conv, last_conv_block_op);
InsertFQLayer(graph_data.fq_bias, last_conv_block_op);
last_conv_block_op = InsertFQLayer(graph_data.fq_bias, last_conv_block_op);
}
// Max pooling
@ -326,7 +301,7 @@ static std::shared_ptr<ngraph::Node> Create1DConv(const GraphData& graph_data, c
if (graph_data.af && graph_data.conv_count == 1) {
last_conv_block_op = graph_data.af->copy_with_new_inputs({last_conv_block_op});
copy_runtime_info(conv, last_conv_block_op);
InsertFQLayer(graph_data.fq_af, last_conv_block_op);
last_conv_block_op = InsertFQLayer(graph_data.fq_af, last_conv_block_op);
}
// Transpose NCHW => NHWC
@ -472,6 +447,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
std::shared_ptr<ngraph::Node> conv,
std::shared_ptr<ngraph::Node> trailing_transpose,
std::shared_ptr<ngraph::Node> bias,
std::shared_ptr<ngraph::Node> bias_const,
std::shared_ptr<ngraph::Node> fq_bias,
std::shared_ptr<ngraph::Node> max_pool,
std::shared_ptr<ngraph::Node> af,
@ -486,7 +462,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
std::dynamic_pointer_cast<ngraph::opset7::MaxPool>(max_pool),
std::dynamic_pointer_cast<ngraph::op::util::UnaryElementwiseArithmetic>(af),
std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(fq_af),
last_op_for_replacement, nullptr, 1, 1, 1};
last_op_for_replacement, bias_const, 1, 1, 1};
ConvData conv_data;
if (!VerifyAndGetConvData(std::dynamic_pointer_cast<ngraph::opset7::Convolution>(conv), conv_data))
@ -500,9 +476,6 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
if (!TransposeOrderMatches(std::dynamic_pointer_cast<ngraph::opset7::Transpose>(trailing_transpose), {0, 2, 3, 1}))
return false;
if (bias && !(graph_data.bias_const = VerifyBiasAndReshapeConst(std::dynamic_pointer_cast<ngraph::opset7::Add>(bias), conv_data)))
return false;
if (max_pool && !VerifyMaxPool(graph_data, std::dynamic_pointer_cast<ngraph::opset7::MaxPool>(max_pool)))
return false;
@ -515,22 +488,6 @@ static bool Convert(std::shared_ptr<ngraph::Node> leading_transpose,
return true;
}
static bool VerifyBias(std::shared_ptr<ngraph::Node> conv, std::shared_ptr<ngraph::Node> bias) {
auto add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(1).get_node_shared_ptr());
if (!add_const) {
add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(0).get_node_shared_ptr());
}
if (!add_const) {
auto bias_size = shape_size(add_const->get_shape());
auto conv_filter_count = conv->input_value(1).get_shape()[0];
if (bias_size == conv_filter_count)
return true;
}
return false;
}
Decompose2DConv::Decompose2DConv() {
MATCHER_SCOPE(Decompose2DConv);
@ -576,6 +533,11 @@ Decompose2DConv::Decompose2DConv() {
auto fq_conv_node = (fq_conv_it == std::end(pattern_map) ? nullptr : fq_conv_it->second.get_node_shared_ptr());
auto bias_it = pattern_map.find(bias);
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
std::shared_ptr<ngraph::Node> bias_const_node = nullptr;
if (bias_node && !(bias_const_node = VerifyBiasGetConst(pattern_map.at(conv).get_node_shared_ptr(), bias_node)))
return false;
auto fq_bias_it = pattern_map.find(fq_bias);
auto fq_bias_node = (fq_bias_it == std::end(pattern_map) ? nullptr : fq_bias_it->second.get_node_shared_ptr());
auto fq_af_it = pattern_map.find(fq_af);
@ -596,7 +558,7 @@ Decompose2DConv::Decompose2DConv() {
}
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), fq_conv_node, pattern_map.at(conv).get_node_shared_ptr(),
pattern_map.at(trailing_transpose).get_node_shared_ptr(), bias_node, fq_bias_node, max_pool_node, af_node, fq_af_node,
pattern_map.at(trailing_transpose).get_node_shared_ptr(), bias_node, bias_const_node, fq_bias_node, max_pool_node, af_node, fq_af_node,
pattern_map.at(trailing_transpose).get_node_shared_ptr());
};
@ -621,11 +583,13 @@ Decompose2DConvTransposedWithBias::Decompose2DConvTransposedWithBias() {
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
if (!VerifyBias(pattern_map.at(conv).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr()))
std::shared_ptr<ngraph::Node> bias_const_node = nullptr;
if (!(bias_const_node = VerifyBiasGetConst(pattern_map.at(conv).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr())))
return false;
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), nullptr, pattern_map.at(conv).get_node_shared_ptr(),
pattern_map.at(trailing_transpose).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr(), nullptr, nullptr,
pattern_map.at(trailing_transpose).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr(), bias_const_node, nullptr, nullptr,
nullptr, nullptr, pattern_map.at(bias).get_node_shared_ptr());
};
@ -654,11 +618,13 @@ Decompose2DConvTransposedWithBiasAF::Decompose2DConvTransposedWithBiasAF() {
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
if (!VerifyBias(pattern_map.at(conv).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr()))
std::shared_ptr<ngraph::Node> bias_const_node = nullptr;
if (!(bias_const_node = VerifyBiasGetConst(pattern_map.at(conv).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr())))
return false;
return Convert(pattern_map.at(leading_transpose).get_node_shared_ptr(), nullptr, pattern_map.at(conv).get_node_shared_ptr(),
pattern_map.at(trailing_transpose).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr(), nullptr,
pattern_map.at(trailing_transpose).get_node_shared_ptr(), pattern_map.at(bias).get_node_shared_ptr(), bias_const_node, nullptr,
nullptr, pattern_map.at(af).get_node_shared_ptr(), nullptr, pattern_map.at(af).get_node_shared_ptr());
};

View File

@ -5,6 +5,7 @@
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include "transformation_helper.hpp"
@ -72,4 +73,29 @@ std::shared_ptr<ngraph::opset7::StridedSlice> FlatCrop(ngraph::Output<ngraph::No
std::vector<int64_t>{1, 0}); // end mask
}
std::shared_ptr<ngraph::Node> VerifyBiasGetConst(std::shared_ptr<ngraph::Node> conv, std::shared_ptr<ngraph::Node> bias) {
auto add_const = std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias->input_value(1).get_node_shared_ptr());
// Check if it's really a bias and not just an addition
if (add_const) {
auto bias_size = shape_size(add_const->get_shape());
auto conv_filter_count = conv->get_output_shape(0)[1];
if (bias_size == conv_filter_count)
return add_const;
}
return nullptr;
}
std::shared_ptr<ngraph::Node> InsertFQLayer(const std::shared_ptr<ngraph::opset7::FakeQuantize> fq_layer,
std::shared_ptr<ngraph::Node> last_node) {
if (fq_layer != nullptr) {
auto new_fq = fq_layer->clone_with_new_inputs({last_node,
fq_layer->input_value(1), fq_layer->input_value(2),
fq_layer->input_value(3), fq_layer->input_value(4)});
ngraph::copy_runtime_info(fq_layer, new_fq);
return new_fq;
}
return last_node;
}
} // namespace GNAPluginNS
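To illustrate VerifyBiasGetConst: the Add qualifies as a bias only when its constant has exactly one element per convolution output channel. A hedged, self-contained sketch (hypothetical shapes):

#include <ngraph/opsets/opset7.hpp>
#include "utils/transformation_helper.hpp"

// A 32-output-channel convolution plus an Add with a {1, 32, 1, 1} constant
// is recognized as conv + bias; any other element count yields nullptr.
bool example_bias_check() {
    using namespace ngraph;
    auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 8, 1, 5});
    auto weights = opset7::Constant::create(element::f32, Shape{32, 8, 1, 3},
                                            std::vector<float>(32 * 8 * 3, 0.1f));
    auto conv = std::make_shared<opset7::Convolution>(input, weights, Strides{1, 1},
        CoordinateDiff{0, 1}, CoordinateDiff{0, 1}, Strides{1, 1}); // output {1, 32, 1, 5}
    auto bias_const = opset7::Constant::create(element::f32, Shape{1, 32, 1, 1},
                                               std::vector<float>(32, 0.0f));
    auto add = std::make_shared<opset7::Add>(conv, bias_const);
    return GNAPluginNS::VerifyBiasGetConst(conv, add) != nullptr; // true
}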

View File

@ -61,4 +61,21 @@ bool TransposeOrderMatches(std::shared_ptr<ngraph::opset7::Transpose> transpose,
* @return pointer to the newly created slice
*/
std::shared_ptr<ngraph::opset7::StridedSlice> FlatCrop(ngraph::Output<ngraph::Node> input, size_t offset, size_t size);
/**
* @brief checks whether the Add present after a convolution is a bias and gets its const input
* @param conv convolution layer preceding potential bias
* @param bias potential bias layer passed from ngraph matcher
* @return bias const if the add layer present after convolution is a bias, nullptr otherwise
*/
std::shared_ptr<ngraph::Node> VerifyBiasGetConst(std::shared_ptr<ngraph::Node> conv, std::shared_ptr<ngraph::Node> bias);
/**
* @brief inserts a new fake quantize layer copied from an existing fake quantize layer (if one exists) and connects it to the output of a given layer
* @param fq_layer existing fake quantize layer to be copied
* @param last_node the node to whose output the new fake quantize layer will be connected
* @return new fake quantize layer or the last node
*/
std::shared_ptr<ngraph::Node> InsertFQLayer(const std::shared_ptr<ngraph::opset7::FakeQuantize> fq_layer, std::shared_ptr<ngraph::Node> last_node);
} // namespace GNAPluginNS
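Note that, unlike the old void helper removed from decompose_2d_convolution.cpp, InsertFQLayer returns the node to use downstream, so call sites must reassign the result; a small usage sketch:

#include <ngraph/opsets/opset7.hpp>
#include "utils/transformation_helper.hpp"

// Wraps a node with a copy of an existing FakeQuantize (no-op when fq_layer is null).
std::shared_ptr<ngraph::Node> wrap_with_fq(std::shared_ptr<ngraph::opset7::FakeQuantize> fq_layer,
                                           std::shared_ptr<ngraph::Node> node) {
    node = GNAPluginNS::InsertFQLayer(fq_layer, node); // must reassign the returned node
    return node;
}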

View File

@ -0,0 +1,215 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/test_common.hpp"
#include <string>
#include <sstream>
#include <fstream>
#include <memory>
#include <queue>
#include <map>
#include "transformations/init_node_info.hpp"
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
using namespace ngraph;
using namespace ngraph::opset7;
namespace LayerTestsDefinitions {
enum class modelType {
TranspDWSCTransp = 0, /* Transpose(NHWC->NCHW) => DWSC (Group Convolution) => Transpose(NCHW->NHWC) */
TranspDWSCBiasTransp, /* Transpose(NHWC->NCHW) => DWSC => Broadcasted Add (Bias) => Transpose(NCHW->NHWC) */
};
typedef std::tuple<
InferenceEngine::SizeVector, // Kernel size
InferenceEngine::SizeVector, // Strides
std::vector<ptrdiff_t>, // Pad begin
std::vector<ptrdiff_t>, // Pad end
InferenceEngine::SizeVector, // Dilation
op::PadType, // Padding type
size_t, // Num out channels
size_t, // Num groups
InferenceEngine::SizeVector // Bias
> DWSCParams;
typedef std::tuple<
DWSCParams, // DWSC and bias parameters
InferenceEngine::Precision, // Network Precision
std::string, // Target Device
std::map<std::string, std::string>, // Configuration
InferenceEngine::SizeVector, // Input shapes
modelType // Test model
> DWSCToScaleShiftsParams;
class DWSCToScaleShiftsTest : public testing::WithParamInterface<DWSCToScaleShiftsParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<DWSCToScaleShiftsParams> obj) {
DWSCParams params;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::map<std::string, std::string> configuration;
InferenceEngine::SizeVector inputShape;
modelType model;
std::tie(params, netPrecision, targetDevice, configuration, inputShape, model) = obj.param;
op::PadType padType;
InferenceEngine::SizeVector filter, stride, dilation, bias;
std::vector<ptrdiff_t> padBegin, padEnd;
size_t numOutChannels, numGroups;
std::tie(filter, stride, padBegin, padEnd, dilation, padType, numOutChannels, numGroups, bias) = params;
std::ostringstream result;
result << "M=" << static_cast<uint32_t>(model) << "_";
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
result << "K" << CommonTestUtils::vec2str(filter) << "_";
result << "S" << CommonTestUtils::vec2str(stride) << "_";
result << "PB" << CommonTestUtils::vec2str(padBegin) << "_";
result << "PE" << CommonTestUtils::vec2str(padEnd) << "_";
result << "D=" << CommonTestUtils::vec2str(dilation) << "_";
result << "O=" << numOutChannels << "_";
result << "AP=" << padType << "_";
result << "B=" << CommonTestUtils::vec2str(bias) << "_";
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
for (auto const& configItem : configuration) {
result << "_configItem=" << configItem.first << "_" << configItem.second;
}
return result.str();
}
protected:
void SetUp() override {
threshold = 0.05f;
DWSCParams params;
InferenceEngine::Precision netPrecision;
std::vector<size_t> inputShape;
modelType model;
std::tie(params, netPrecision, targetDevice, configuration, inputShape, model) = this->GetParam();
op::PadType padType;
InferenceEngine::SizeVector filter, stride, dilation, bias;
std::vector<ptrdiff_t> padBegin, padEnd;
size_t numOutChannels, numGroups;
std::tie(filter, stride, padBegin, padEnd, dilation, padType, numOutChannels, numGroups, bias) = params;
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto input = builder::makeParams(ngPrc, {inputShape});
auto transposeInOrder = op::Constant::create(element::i64, Shape{4}, {0, 3, 1, 2});
auto transposeIn = std::make_shared<Transpose>(input[0], transposeInOrder);
auto filterSize = std::accumulate(std::begin(filter), std::end(filter), 1ull, std::multiplies<size_t>());
auto filterWeights = CommonTestUtils::generate_float_numbers(numOutChannels * (inputShape[3] / numGroups) * filterSize, -0.5f, 0.5f);
auto dwsc = builder::makeGroupConvolution(transposeIn, ngPrc, filter, stride, padBegin,
padEnd, dilation, padType, numOutChannels, numGroups, false, filterWeights);
auto transposeOutOrder = op::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1});
auto lastOp = std::make_shared<Transpose>(dwsc, transposeOutOrder);
if (model == modelType::TranspDWSCBiasTransp) {
Shape biasShape{bias};
auto biasWeights = CommonTestUtils::generate_float_numbers(shape_size(biasShape), -1.0f, 1.0f);
auto biasConst = std::make_shared<Constant>(ngPrc, biasShape, biasWeights);
auto bias = std::make_shared<Add>(dwsc, biasConst);
lastOp = std::make_shared<Transpose>(bias, transposeOutOrder);
}
auto result = std::make_shared<Result>(lastOp);
function = std::make_shared<Function>(ResultVector{result}, ParameterVector{input});
}
};
TEST_P(DWSCToScaleShiftsTest, CompareWithRefs) {
Run();
}
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
};
const std::vector<std::map<std::string, std::string>> configs = {
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "1"},
}
};
const std::vector<op::PadType> padTypes = {
op::PadType::VALID,
op::PadType::EXPLICIT,
op::PadType::SAME_LOWER,
op::PadType::SAME_UPPER
};
const std::vector<modelType> models = {
modelType::TranspDWSCTransp,
modelType::TranspDWSCBiasTransp
};
const std::vector<std::vector<size_t>> inputNHWC = {{1, 1, 5, 32}};
const std::vector<std::vector<size_t >> filters = {{1, 3}};
const std::vector<std::vector<size_t >> strides = {{1, 1}, {1, 2}};
const std::vector<std::vector<ptrdiff_t>> padBegins = {{0, 1}, {0, 2}};
const std::vector<std::vector<ptrdiff_t>> padEnds = {{0, 1}};
const std::vector<std::vector<size_t >> dilations = {{1, 1}};
const std::vector<size_t> numOutChannels = {32};
const std::vector<size_t> numGroups = {32};
const std::vector<std::vector<size_t >> biases = {{1, 32, 1, 1}};
const auto convParams = ::testing::Combine(
::testing::ValuesIn(filters),
::testing::ValuesIn(strides),
::testing::ValuesIn(padBegins),
::testing::ValuesIn(padEnds),
::testing::ValuesIn(dilations),
::testing::ValuesIn(padTypes),
::testing::ValuesIn(numOutChannels),
::testing::ValuesIn(numGroups),
::testing::ValuesIn(biases)
);
INSTANTIATE_TEST_CASE_P(smoke_DWSCToScaleShifts, DWSCToScaleShiftsTest,
::testing::Combine(
convParams,
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(inputNHWC),
::testing::ValuesIn(models)),
DWSCToScaleShiftsTest::getTestCaseName);
/* ============= Strides & Dilations Combination ============= */
const std::vector<op::PadType> padTypesSD = {
op::PadType::VALID,
};
const std::vector<std::vector<size_t>> inputNHWCSD = {{1, 1, 8, 32}};
const std::vector<std::vector<size_t >> dilationsSD = {{1, 1}, {1, 2}};
const auto convParamsSD = ::testing::Combine(
::testing::ValuesIn(filters),
::testing::ValuesIn(strides),
::testing::ValuesIn(padBegins),
::testing::ValuesIn(padEnds),
::testing::ValuesIn(dilationsSD),
::testing::ValuesIn(padTypesSD),
::testing::ValuesIn(numOutChannels),
::testing::ValuesIn(numGroups),
::testing::ValuesIn(biases)
);
INSTANTIATE_TEST_CASE_P(smoke_DWSCToScaleShiftsStridesDilations, DWSCToScaleShiftsTest,
::testing::Combine(
convParamsSD,
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(inputNHWCSD),
::testing::ValuesIn(models)),
DWSCToScaleShiftsTest::getTestCaseName);
} // namespace LayerTestsDefinitions

View File

@ -57,12 +57,12 @@ typedef std::tuple<
std::map<std::string, std::string>, // Configuration
InferenceEngine::SizeVector, // Input shapes
modelType // Test model
> padded2ValidParams;
> paddedToValidParams;
class Padded2ValidConvTest : public testing::WithParamInterface<padded2ValidParams>,
class PaddedToValidConvTest : public testing::WithParamInterface<paddedToValidParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<padded2ValidParams> obj) {
static std::string getTestCaseName(testing::TestParamInfo<paddedToValidParams> obj) {
convSpecificParams convParams;
miscSpecificParams miscParams;
InferenceEngine::Precision netPrecision;
@ -195,26 +195,26 @@ protected:
}
};
class Gna30Padded2ValidConvTest : public Padded2ValidConvTest, GnaLayerTestCheck {
class Gna30PaddedToValidConvTest : public PaddedToValidConvTest, GnaLayerTestCheck {
protected:
void Run() override {
GnaLayerTestCheck::SkipTestCheck();
if (!GnaLayerTestCheck::skipTest) {
Padded2ValidConvTest::Run();
PaddedToValidConvTest::Run();
}
}
void SetUp() override {
Padded2ValidConvTest::SetUp();
PaddedToValidConvTest::SetUp();
}
};
TEST_P(Padded2ValidConvTest, CompareWithRefs) {
TEST_P(PaddedToValidConvTest, CompareWithRefs) {
Run();
}
TEST_P(Gna30Padded2ValidConvTest, CompareWithRefs) {
TEST_P(Gna30PaddedToValidConvTest, CompareWithRefs) {
Run();
}
@ -322,7 +322,7 @@ const auto misc2DParams = ::testing::Combine(
::testing::ValuesIn(maxpool2DStrides)
);
INSTANTIATE_TEST_CASE_P(smoke_1DPadded2Valid, Padded2ValidConvTest,
INSTANTIATE_TEST_CASE_P(smoke_1DPaddedToValid, PaddedToValidConvTest,
::testing::Combine(
conv1DParams,
misc1DParams,
@ -331,9 +331,9 @@ INSTANTIATE_TEST_CASE_P(smoke_1DPadded2Valid, Padded2ValidConvTest,
::testing::ValuesIn(configs1D),
::testing::ValuesIn(input1DNHWC),
::testing::ValuesIn(models)),
Padded2ValidConvTest::getTestCaseName);
PaddedToValidConvTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_1DPadded2Valid, Gna30Padded2ValidConvTest,
INSTANTIATE_TEST_CASE_P(smoke_1DPaddedToValid, Gna30PaddedToValidConvTest,
::testing::Combine(
conv1DParams,
misc1DParams,
@ -342,9 +342,9 @@ INSTANTIATE_TEST_CASE_P(smoke_1DPadded2Valid, Gna30Padded2ValidConvTest,
::testing::ValuesIn(configs1D_Gna30),
::testing::ValuesIn(input1DNHWC),
::testing::ValuesIn(models)),
Gna30Padded2ValidConvTest::getTestCaseName);
Gna30PaddedToValidConvTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_2DPadded2Valid, Gna30Padded2ValidConvTest,
INSTANTIATE_TEST_CASE_P(smoke_2DPaddedToValid, Gna30PaddedToValidConvTest,
::testing::Combine(
conv2DParams,
misc2DParams,
@ -353,6 +353,6 @@ INSTANTIATE_TEST_CASE_P(smoke_2DPadded2Valid, Gna30Padded2ValidConvTest,
::testing::ValuesIn(configs2D),
::testing::ValuesIn(input2DNHWC),
::testing::ValuesIn(models)),
Gna30Padded2ValidConvTest::getTestCaseName);
Gna30PaddedToValidConvTest::getTestCaseName);
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,384 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <tuple>
#include "transformations/convert_dwsc_to_scaleshifts.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
namespace testing {
namespace {
enum class modelType {
TranspDWSCTransp = 0, /* Transpose(NHWC->NCHW) => DWSC (Group Convolution) => Transpose(NCHW->NHWC) */
TranspDWSCBiasTransp, /* Transpose(NHWC->NCHW) => DWSC => Broadcasted Add (Bias) => Transpose(NCHW->NHWC) */
};
typedef std::tuple<
modelType, // Test model
ngraph::Shape, // Input shape
ngraph::Shape, // Convolution filter shape
ngraph::Strides, // Convolution stride
ngraph::CoordinateDiff, // Convolution pads begin
ngraph::CoordinateDiff, // Convolution pads end
ngraph::Strides, // Convolution dilation
ngraph::Shape, // Bias shape
ngraph::op::PadType // Padding type
> DWSCToScaleShiftsParams;
typedef std::tuple<
bool, // With / without Fake Quantize layers
DWSCToScaleShiftsParams // Test parameters
> fqDWSCToScaleShiftsParams;
std::shared_ptr<ngraph::opset7::FakeQuantize> createFQ(std::shared_ptr<ngraph::Node>& in_node) {
auto input_low = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto input_high = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {5});
auto output_low = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto output_high = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
return std::make_shared<ngraph::opset7::FakeQuantize>(in_node, input_low, input_high, output_low, output_high, 11);
}
std::shared_ptr<ngraph::Node> createBiasFQ(const std::shared_ptr<ngraph::Node>& in_node,
std::shared_ptr<ngraph::opset7::Constant>& bias_const, const bool& fq) {
std::shared_ptr<ngraph::Node> node;
node = std::make_shared<ngraph::opset7::Add>(in_node, bias_const);
if (fq) {
node = createFQ(node);
}
return node;
}
std::shared_ptr<ngraph::opset7::Result> createFunction(const bool& fq,
const modelType& model,
const ngraph::Output<ngraph::Node>& input_node,
const ngraph::Shape& filters_shape,
const ngraph::Strides& conv_stride,
const ngraph::CoordinateDiff& pads_begin,
const ngraph::CoordinateDiff& pads_end,
const ngraph::Strides& conv_dilation,
const ngraph::Shape& bias_shape,
const ngraph::op::PadType& pad_type,
std::shared_ptr<ngraph::opset7::GroupConvolution>& dwsc,
std::shared_ptr<ngraph::opset7::Constant>& bias_const,
std::shared_ptr<ngraph::opset7::FakeQuantize>& fq_bias) {
std::shared_ptr<ngraph::Node> fq_filters;
auto transpose_in_order = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64, ngraph::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
auto transpose_in = std::make_shared<ngraph::opset7::Transpose>(input_node, transpose_in_order);
if (fq) {
fq_filters = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64,
ngraph::Shape{input_node.get_shape()[3], 1, filters_shape[0], filters_shape[1]});
fq_filters = createFQ(fq_filters);
fq_filters = std::make_shared<ngraph::opset7::Reshape>(fq_filters,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{5},
ngraph::Shape{input_node.get_shape()[3], 1, 1, filters_shape[0], filters_shape[1]}), false);
} else {
fq_filters = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64,
ngraph::Shape{input_node.get_shape()[3], 1, 1, filters_shape[0], filters_shape[1]});
}
dwsc = std::make_shared<ngraph::opset7::GroupConvolution>(transpose_in, fq_filters, conv_stride, pads_begin, pads_end, conv_dilation, pad_type);
auto transpose_out_order = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64, ngraph::Shape{4}, std::vector<int64_t>{0, 2, 3, 1});
auto last_op = std::make_shared<ngraph::opset7::Transpose>(dwsc, transpose_out_order);
if (model == modelType::TranspDWSCBiasTransp || fq) {
bias_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::i64, bias_shape);
auto bias = createBiasFQ(dwsc, bias_const, fq);
fq_bias = std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(bias);
last_op = std::make_shared<ngraph::opset7::Transpose>(bias, transpose_out_order);
}
return std::make_shared<ngraph::opset7::Result>(last_op);
}
std::shared_ptr<ngraph::Function> get_initial_function(const bool& fq,
const modelType& model,
const ngraph::Shape& input_shape,
const ngraph::Shape& filters_shape,
const ngraph::Strides& conv_stride,
const ngraph::CoordinateDiff& pads_begin,
const ngraph::CoordinateDiff& pads_end,
const ngraph::Strides& conv_dilation,
const ngraph::Shape& bias_shape,
const ngraph::op::PadType& pad_type,
std::shared_ptr<ngraph::opset7::GroupConvolution>& dwsc,
std::shared_ptr<ngraph::opset7::Constant>& bias_const,
std::shared_ptr<ngraph::opset7::FakeQuantize>& fq_bias) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
auto result = createFunction(fq, model, input_params, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
bias_shape, pad_type, dwsc, bias_const, fq_bias);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}
// ---------------------------------------------------------------------------------------------------------------------
class ConvertDWSCToScaleShiftsTestInvalidFixture : public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<fqDWSCToScaleShiftsParams> {
public:
void SetUp() override;
public:
std::shared_ptr<ngraph::Function> function, reference_function;
modelType model;
};
void ConvertDWSCToScaleShiftsTestInvalidFixture::SetUp() {
bool fq;
DWSCToScaleShiftsParams params;
ngraph::Shape input_shape;
ngraph::Shape filters_shape, bias_shape;
ngraph::Strides conv_stride, conv_dilation;
ngraph::CoordinateDiff pads_begin, pads_end;
ngraph::op::PadType pad_type;
std::shared_ptr<ngraph::opset7::GroupConvolution> dwsc;
std::shared_ptr<ngraph::opset7::Constant> bias_const;
std::shared_ptr<ngraph::opset7::FakeQuantize> fq_bias;
std::tie(fq, params) = this->GetParam();
std::tie(model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
bias_shape, pad_type) = params;
function = get_initial_function(fq, model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
bias_shape, pad_type, dwsc, bias_const, fq_bias);
reference_function = get_initial_function(fq, model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
bias_shape, pad_type, dwsc, bias_const, fq_bias);
}
// ---------------------------------------------------------------------------------------------------------------------
class ConvertDWSCToScaleShiftsTestFixture: public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<fqDWSCToScaleShiftsParams> {
public:
void SetUp() override;
std::shared_ptr<ngraph::Function> get_reference(const bool& fq,
const modelType& model,
const ngraph::Shape& input_shape,
const ngraph::Shape& filters_shape,
const ngraph::Strides& conv_stride,
const ngraph::CoordinateDiff& pads_begin,
const ngraph::CoordinateDiff& pads_end,
const ngraph::Strides& conv_dilation,
const ngraph::Shape& bias_shape,
const ngraph::op::PadType& pad_type,
const std::shared_ptr<ngraph::opset7::GroupConvolution>& dwsc,
const std::shared_ptr<ngraph::opset7::Constant>& bias_const,
const std::shared_ptr<ngraph::opset7::FakeQuantize>& fq_bias);
public:
std::shared_ptr<ngraph::Function> function, reference_function;
modelType model;
};
void ConvertDWSCToScaleShiftsTestFixture::SetUp() {
bool fq;
DWSCToScaleShiftsParams params;
ngraph::Shape input_shape;
ngraph::Shape filters_shape, bias_shape;
ngraph::Strides conv_stride, conv_dilation;
ngraph::CoordinateDiff pads_begin, pads_end;
ngraph::op::PadType pad_type;
std::shared_ptr<ngraph::opset7::GroupConvolution> dwsc;
std::shared_ptr<ngraph::opset7::Constant> bias_const;
std::shared_ptr<ngraph::opset7::FakeQuantize> fq_bias;
std::tie(fq, params) = this->GetParam();
std::tie(model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
bias_shape, pad_type) = params;
function = get_initial_function(fq, model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
bias_shape, pad_type, dwsc, bias_const, fq_bias);
reference_function = get_reference(fq, model, input_shape, filters_shape, conv_stride, pads_begin, pads_end, conv_dilation,
bias_shape, pad_type, dwsc, bias_const, fq_bias);
}
std::shared_ptr<ngraph::opset7::StridedSlice> FlatCrop(ngraph::Output<ngraph::Node> input, size_t offset, size_t size) {
return std::make_shared<ngraph::opset7::StridedSlice>(
input, // data
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {(size_t)0, offset}), // begin slice index
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {(size_t)0, offset + size}), // end slice index
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {(size_t)1, (size_t)1}), // strides
std::vector<int64_t>{1, 0}, // begin mask
std::vector<int64_t>{1, 0}); // end mask
}
std::shared_ptr<ngraph::Node> InsertFQLayer(const std::shared_ptr<ngraph::opset7::FakeQuantize> fq_layer,
std::shared_ptr<ngraph::Node> last_node) {
if (fq_layer != nullptr) {
return fq_layer->clone_with_new_inputs({last_node,
fq_layer->input_value(1), fq_layer->input_value(2),
fq_layer->input_value(3), fq_layer->input_value(4)});
}
return last_node;
}
std::shared_ptr<ngraph::Node> DecomposeDWSC(std::shared_ptr<ngraph::opset7::GroupConvolution> dwsc,
std::shared_ptr<ngraph::opset7::Constant> bias_const, std::shared_ptr<ngraph::opset7::FakeQuantize> fq_bias,
std::shared_ptr<ngraph::opset7::Reshape> flat_input_plane, std::shared_ptr<ngraph::Node> flat_filters_plane) {
std::shared_ptr<ngraph::opset7::Constant> const_zero_padding;
std::shared_ptr<ngraph::Node> reshaped_bias;
ngraph::OutputVector output_chunks;
auto input_channel_count = dwsc->get_input_shape(0)[1];
auto input_width = dwsc->get_input_shape(0)[3];
auto output_width = dwsc->get_output_shape(0)[3];
auto filter_width = dwsc->get_input_shape(1)[4];
auto pads_begin = dwsc->get_pads_begin()[1];
auto stride_width = dwsc->get_strides()[1];
auto dilation_width = dwsc->get_dilations()[1];
// Constant with zero padding
if (pads_begin) {
const_zero_padding = std::make_shared<ngraph::opset7::Constant>(dwsc->get_element_type(), ngraph::Shape{1, input_channel_count}, 0);
}
// Reshape bias const
if (bias_const) {
auto bias_size = shape_size(bias_const->get_shape());
reshaped_bias = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(bias_const,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, bias_size}), false);
}
// Move filter over input performing multiplication and addition (scaleshift), take padding, stride, dilation and bias into account
for (int32_t input_position = -pads_begin, o = 0; o < output_width; input_position += stride_width, o++) {
std::shared_ptr<ngraph::Node> previous_layer_output, last_layer_output;
int32_t filter_end = input_position + filter_width * dilation_width;
bool first = true;
filter_end = filter_end < input_width ? filter_end : input_width;
for (int32_t filter_pos = input_position, filter_idx = 0; filter_pos < filter_end; filter_pos += dilation_width, filter_idx++) {
if (filter_pos >= 0) {
auto conv_input_slice = FlatCrop(flat_input_plane, filter_pos * input_channel_count, input_channel_count);
auto conv_filter_slice = FlatCrop(flat_filters_plane, filter_idx * input_channel_count, input_channel_count);
if (first) {
first = false;
previous_layer_output = std::make_shared<ngraph::opset7::Multiply>(conv_input_slice, conv_filter_slice);
if (bias_const) {
previous_layer_output = std::make_shared<ngraph::opset7::Add>(previous_layer_output, reshaped_bias);
previous_layer_output = InsertFQLayer(fq_bias, previous_layer_output);
}
last_layer_output = previous_layer_output;
} else {
last_layer_output = std::make_shared<ngraph::opset7::Multiply>(conv_input_slice, conv_filter_slice);
last_layer_output = std::make_shared<ngraph::opset7::Add>(last_layer_output, previous_layer_output);
previous_layer_output = last_layer_output;
}
}
}
if (!last_layer_output) {
IE_ASSERT(const_zero_padding);
last_layer_output = const_zero_padding;
}
output_chunks.push_back(last_layer_output);
}
// Concat is only needed when output width > 1
if (output_chunks.size() > 1) {
return std::make_shared<ngraph::opset7::Concat>(output_chunks, 0);
}
return output_chunks[0].get_node_shared_ptr();
}
std::shared_ptr<ngraph::Function> ConvertDWSCToScaleShiftsTestFixture::get_reference(const bool& fq,
const modelType& model,
const ngraph::Shape& input_shape,
const ngraph::Shape& filters_shape,
const ngraph::Strides& conv_stride,
const ngraph::CoordinateDiff& pads_begin,
const ngraph::CoordinateDiff& pads_end,
const ngraph::Strides& conv_dilation,
const ngraph::Shape& bias_shape,
const ngraph::op::PadType& pad_type,
const std::shared_ptr<ngraph::opset7::GroupConvolution>& dwsc,
const std::shared_ptr<ngraph::opset7::Constant>& bias_const,
const std::shared_ptr<ngraph::opset7::FakeQuantize>& fq_bias) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
auto output_channel_count = dwsc->get_output_shape(0)[1];
auto output_width = dwsc->get_output_shape(0)[3];
// Prepare flat input data
auto flat_input_plane = std::make_shared<ngraph::opset7::Reshape>(input_params,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
ngraph::Shape{1, ngraph::shape_size(input_shape)}), false);
// Prepare flat filter data
auto filters_const = std::dynamic_pointer_cast<ngraph::Node>(dwsc->get_input_node_shared_ptr(1));
auto filters_size = ngraph::shape_size(filters_const->get_shape());
auto transposed_filters_const = ngraph::op::util::make_try_fold<ngraph::opset7::Transpose>(filters_const,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{5}, ngraph::Shape{4, 1, 2, 3, 0}));
auto flat_filters_plane = ngraph::op::util::make_try_fold<ngraph::opset7::Reshape>(transposed_filters_const,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, filters_size}), false);
// Convert DWSC to a set of diagonal layers
auto output_plane = DecomposeDWSC(dwsc, bias_const, fq_bias, flat_input_plane, flat_filters_plane);
// Restore the original output shape
auto result = std::make_shared<ngraph::opset7::Reshape>(output_plane,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
ngraph::Shape{1, output_channel_count, 1, output_width}), false);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{std::make_shared<ngraph::opset7::Result>(result)}, ngraph::ParameterVector{input_params});
}
// ---------------------------------------------------------------------------------------------------------------------
void execute_test(modelType model, std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<GNAPluginNS::ConvertDWSCToScaleShifts>();
manager.run_passes(function);
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid);
}
TEST_P(ConvertDWSCToScaleShiftsTestFixture, CompareFunctions) {
execute_test(model, function, reference_function);
}
INSTANTIATE_TEST_SUITE_P(ConvertDWSCToScaleShiftsTestSuite, ConvertDWSCToScaleShiftsTestFixture,
::testing::Combine(
// With / without Fake Quantize layers
::testing::Values(true, false),
::testing::Values(
std::make_tuple(modelType::TranspDWSCTransp, ngraph::Shape{1, 1, 5, 32}, ngraph::Shape{1, 3}, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 1}, ngraph::CoordinateDiff{0, 1}, ngraph::Strides{1, 1},
ngraph::Shape{1, 32, 1, 1}, ngraph::op::PadType::VALID),
std::make_tuple(modelType::TranspDWSCBiasTransp, ngraph::Shape{1, 1, 5, 32}, ngraph::Shape{1, 3}, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 2}, ngraph::CoordinateDiff{0, 2}, ngraph::Strides{1, 1},
ngraph::Shape{1, 32, 1, 1}, ngraph::op::PadType::VALID))));
TEST_P(ConvertDWSCToScaleShiftsTestInvalidFixture, CompareFunctions) {
execute_test(model, function, reference_function);
}
INSTANTIATE_TEST_SUITE_P(ConvertDWSCToScaleShiftsInvalidTestSuite, ConvertDWSCToScaleShiftsTestInvalidFixture,
::testing::Combine(
// With / without Fake Quantize layers
::testing::Values(true, false),
::testing::Values(
std::make_tuple(modelType::TranspDWSCTransp, ngraph::Shape{2, 16, 8, 1}, ngraph::Shape{1, 2}, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 2}, ngraph::CoordinateDiff{0, 3}, ngraph::Strides{1, 1},
ngraph::Shape{1, 4, 1, 1}, ngraph::op::PadType::SAME_UPPER),
std::make_tuple(modelType::TranspDWSCBiasTransp, ngraph::Shape{2, 16, 8, 1}, ngraph::Shape{1, 2}, ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 2}, ngraph::CoordinateDiff{0, 3}, ngraph::Strides{1, 1},
ngraph::Shape{1, 4, 1, 1}, ngraph::op::PadType::EXPLICIT))));
} // namespace
} // namespace testing

View File

@ -6,7 +6,7 @@
#include <tuple>
#include "transformations/convert_padded2valid_conv.hpp"
#include "transformations/convert_padded_to_valid_convolution.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
@ -39,12 +39,12 @@ typedef std::tuple<
ngraph::Strides, // Max Pool stride
ngraph::Shape, // Max Pool shape
ngraph::op::PadType // Padding type
> padded2ValidConvParams;
> paddedToValidConvParams;
typedef std::tuple<
bool, // With / without Fake Quantize layers
padded2ValidConvParams // Test parameters
> fqPadded2ValidConvParams;
paddedToValidConvParams // Test parameters
> fqPaddedToValidConvParams;
struct ConvData {
size_t input_height;
@ -193,17 +193,17 @@ std::shared_ptr<ngraph::Function> get_initial_function(const bool& fq,
// ---------------------------------------------------------------------------------------------------------------------
class ConvertPadded2ValidConvTestInvalidFixture : public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<fqPadded2ValidConvParams> {
class ConvertPaddedToValidConvTestInvalidFixture : public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<fqPaddedToValidConvParams> {
public:
void SetUp() override;
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
void ConvertPadded2ValidConvTestInvalidFixture::SetUp() {
void ConvertPaddedToValidConvTestInvalidFixture::SetUp() {
bool fq;
padded2ValidConvParams params;
paddedToValidConvParams params;
modelType model;
ngraph::PartialShape input_shape;
ngraph::Shape filters_shape, bias_shape, maxpool_shape;
@ -223,8 +223,8 @@ void ConvertPadded2ValidConvTestInvalidFixture::SetUp() {
// ---------------------------------------------------------------------------------------------------------------------
class ConvertPadded2ValidConvTestFixture: public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<fqPadded2ValidConvParams> {
class ConvertPaddedToValidConvTestFixture: public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<fqPaddedToValidConvParams> {
public:
void SetUp() override;
std::shared_ptr<ngraph::Function> get_reference(const bool& fq,
@ -244,9 +244,9 @@ public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
void ConvertPadded2ValidConvTestFixture::SetUp() {
void ConvertPaddedToValidConvTestFixture::SetUp() {
bool fq;
padded2ValidConvParams params;
paddedToValidConvParams params;
modelType model;
ngraph::PartialShape input_shape;
ngraph::Shape filters_shape, bias_shape, maxpool_shape;
@ -354,7 +354,7 @@ std::shared_ptr<ngraph::Node> CreatePaddedNet(const ngraph::Output<ngraph::Node>
return padded_input_plane;
}
std::shared_ptr<ngraph::Function> ConvertPadded2ValidConvTestFixture::get_reference(const bool& fq,
std::shared_ptr<ngraph::Function> ConvertPaddedToValidConvTestFixture::get_reference(const bool& fq,
const modelType& model,
const ngraph::PartialShape& input_shape,
const ngraph::Shape& filters_shape,
@ -406,18 +406,18 @@ std::shared_ptr<ngraph::Function> ConvertPadded2ValidConvTestFixture::get_refere
void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<GNAPluginNS::ConvertPadded2ValidConv>();
manager.register_pass<GNAPluginNS::ConvertPaddedToValidConv>();
manager.run_passes(function);
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid);
}
TEST_P(ConvertPadded2ValidConvTestFixture, CompareFunctions) {
TEST_P(ConvertPaddedToValidConvTestFixture, CompareFunctions) {
execute_test(function, reference_function);
}
INSTANTIATE_TEST_SUITE_P(ConvertPadded2ValidConvTestSuite, ConvertPadded2ValidConvTestFixture,
INSTANTIATE_TEST_SUITE_P(ConvertPaddedToValidConvTestSuite, ConvertPaddedToValidConvTestFixture,
::testing::Combine(
// With / without Fake Quantize layers
::testing::Values(true, false),
@ -444,11 +444,11 @@ INSTANTIATE_TEST_SUITE_P(ConvertPadded2ValidConvTestSuite, ConvertPadded2ValidCo
ngraph::CoordinateDiff{0, 2}, ngraph::CoordinateDiff{0, 3}, ngraph::Strides{1, 1},
ngraph::Shape{1, 1, 1, 4}, ngraph::Strides{1, 1}, ngraph::Shape{1, 2}, ngraph::op::PadType::EXPLICIT))));
TEST_P(ConvertPadded2ValidConvTestInvalidFixture, CompareFunctions) {
TEST_P(ConvertPaddedToValidConvTestInvalidFixture, CompareFunctions) {
execute_test(function, reference_function);
}
INSTANTIATE_TEST_SUITE_P(ConvertPadded2ValidConvInvalidTestSuite, ConvertPadded2ValidConvTestInvalidFixture,
INSTANTIATE_TEST_SUITE_P(ConvertPaddedToValidConvInvalidTestSuite, ConvertPaddedToValidConvTestInvalidFixture,
::testing::Combine(
// With / without Fake Quantize layers
::testing::Values(true, false),

View File

@ -6,7 +6,7 @@
#include <tuple>
#include "transformations/decompose_2d_conv.hpp"
#include "transformations/decompose_2d_convolution.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
@ -426,7 +426,7 @@ void TransformInput(const GraphData& graph_data, const ConvParams& conv_params,
*/
// First we need to prepare flat (height = 1) slices of input data proper for flattened (height = 1) filters created later on;
// the input datat is overlapping (duplicated)
// the input data is overlapping (duplicated)
ngraph::OutputVector dilated_input_planes;
for (size_t filter_height = 0; filter_height < conv_params.filter_height; filter_height++) {
size_t offset;
@ -704,10 +704,13 @@ void execute_test(modelType model, std::shared_ptr<ngraph::Function> function, s
case modelType::TranspConvBcastAddActTransp:
case modelType::TranspConvBcastAddMaxPoolActTransp:
manager.register_pass<GNAPluginNS::Decompose2DConv>();
break;
case modelType::TranspConvTranspBcastAdd:
manager.register_pass<GNAPluginNS::Decompose2DConvTransposedWithBias>();
break;
case modelType::TranspConvTranspBcastAddAct:
manager.register_pass<GNAPluginNS::Decompose2DConvTransposedWithBiasAF>();
break;
}
manager.run_passes(function);