[GNA] Support new Kaldi IRs (#9474)

* Support new Kaldi IRs (generated in NHWC layout)

* Update tests with activation and FQ

* Cleanup

* Fix reordering of FQ and MaxPool and an overflow problem

* Fix win

* Update src/plugins/intel_gna/transformations/unfuse_reshape_and_transpose.hpp

Co-authored-by: Elizaveta Lobanova <elizaveta.lobanova@intel.com>

* Update src/plugins/intel_gna/transformations/unfuse_reshape_and_transpose.cpp

Co-authored-by: Elizaveta Lobanova <elizaveta.lobanova@intel.com>

* Update inference-engine/tests/unit/gna/ngraph/transformations/gna_unfuse_reshape_and_transpose.cpp

Co-authored-by: Elizaveta Lobanova <elizaveta.lobanova@intel.com>

* Code review

Co-authored-by: Elizaveta Lobanova <elizaveta.lobanova@intel.com>
Nadezhda Ageeva 2022-01-17 14:16:23 +03:00 committed by GitHub
parent 3c7589184d
commit 56581dbe2e
16 changed files with 697 additions and 79 deletions


@@ -0,0 +1,225 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "transformations/unfuse_reshape_and_transpose.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
namespace testing {
namespace {
class IActivationFactory {
public:
virtual ~IActivationFactory() = default;
virtual std::shared_ptr<ngraph::Node> createNode(const ngraph::Output<ngraph::Node>& in) = 0;
};
template <typename T>
class ActivationFactory : public IActivationFactory {
public:
ActivationFactory() = default;
std::shared_ptr<ngraph::Node> createNode(const ngraph::Output<ngraph::Node>& operation_before) override {
return std::make_shared<T>(operation_before);
}
private:
ActivationFactory(const ActivationFactory&) = delete;
ActivationFactory& operator=(const ActivationFactory& ) = delete;
};
template <>
class ActivationFactory <ngraph::opset8::Clamp> : public IActivationFactory {
public:
ActivationFactory(const double min, const double max) : min_(min), max_(max) {}
std::shared_ptr<ngraph::Node> createNode(const ngraph::Output<ngraph::Node>& operation_before) override {
return std::make_shared<ngraph::opset8::Clamp>(operation_before, min_, max_);
}
private:
ActivationFactory(const ActivationFactory&) = delete;
ActivationFactory& operator=(const ActivationFactory& ) = delete;
private:
const double min_;
const double max_;
};
using ActivationFactoryPtr = std::shared_ptr<IActivationFactory>;
template <typename T, typename ... Args>
ActivationFactoryPtr createActivationFactory(Args&& ... args) {
return std::make_shared<ActivationFactory<T>>(std::forward<Args>(args) ...);
}
static std::shared_ptr<ngraph::Function> createFunction(const ngraph::Shape& conv_input_shape,
const ngraph::Shape& conv_filter_shape,
bool with_bias,
bool with_pool,
ActivationFactoryPtr activation_factory,
bool with_fq,
bool single_reshape_before,
bool single_reshape_after) {
size_t total_in = std::accumulate(std::begin(conv_input_shape), std::end(conv_input_shape), 1, std::multiplies<int>());
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{1, total_in});
std::shared_ptr<ngraph::Node> last_node, last_const;
auto add_fake_quantize = [&](const std::shared_ptr<ngraph::Node>& node) {
auto input_low = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto input_high = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {5});
auto output_low = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto output_high = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
return std::make_shared<ngraph::opset8::FakeQuantize>(node, input_low, input_high, output_low, output_high, 11);
};
if (single_reshape_before) {
auto reshape_in_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, conv_input_shape);
auto reshape_in = std::make_shared<ngraph::opset8::Reshape>(input, reshape_in_const, false);
last_node = reshape_in;
} else {
auto reshape_in_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4},
ngraph::Shape{conv_input_shape[0], conv_input_shape[2], conv_input_shape[3], conv_input_shape[1]});
auto reshape_in = std::make_shared<ngraph::opset8::Reshape>(input, reshape_in_const, false);
auto transpose_in_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{0, 3, 1, 2});
auto transpose_in = std::make_shared<ngraph::opset8::Transpose>(reshape_in, transpose_in_const);
last_node = transpose_in;
}
auto conv_weights = ngraph::opset8::Constant::create(ngraph::element::f32, conv_filter_shape, {1});
last_const = conv_weights;
if (with_fq) {
auto conv_input_fq = add_fake_quantize(last_node);
last_node = conv_input_fq;
auto conv_weights_fq = add_fake_quantize(conv_weights);
last_const = conv_weights_fq;
}
auto conv = std::make_shared<ngraph::opset8::Convolution>(last_node,
last_const,
ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 0},
ngraph::CoordinateDiff{0, 0},
ngraph::Strides{1, 1});
last_node = conv;
auto conv_output_shape = conv->get_output_shape(0);
size_t total_out = std::accumulate(std::begin(conv_output_shape), std::end(conv_output_shape), 1, std::multiplies<int>());
if (with_bias) {
auto add_const = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1, conv_output_shape.at(1), 1, 1}, {1});
auto add = std::make_shared<ngraph::opset8::Add>(conv, add_const);
last_node = add;
}
if (with_fq) {
auto conv_bias_fq = add_fake_quantize(last_node);
last_node = conv_bias_fq;
}
if (with_pool) {
auto pool = std::make_shared<ngraph::opset7::MaxPool>(last_node,
ngraph::Strides{1, 1}, ngraph::Shape{0, 0}, ngraph::Shape{0, 0}, ngraph::Shape{1, 1});
last_node = pool;
}
if (activation_factory) {
if (with_fq) {
auto act_fq_in = add_fake_quantize(last_node);
last_node = act_fq_in;
}
auto act = activation_factory->createNode(last_node);
last_node = act;
if (with_fq) {
auto act_fq_out = add_fake_quantize(last_node);
last_node = act_fq_out;
}
}
auto reshape_out_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, total_out});
if (!single_reshape_after) {
auto transpose_out_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{0, 2, 3, 1});
auto transpose_out = std::make_shared<ngraph::opset8::Transpose>(last_node, transpose_out_const);
last_node = transpose_out;
}
auto reshape_out = std::make_shared<ngraph::opset8::Reshape>(last_node, reshape_out_const, false);
auto result = std::make_shared<ngraph::opset8::Result>(reshape_out);
auto func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input});
return func;
}
typedef std::tuple<
std::tuple<ngraph::Shape, ngraph::Shape, bool, bool>,
bool, // with bias
bool, // with pooling
ActivationFactoryPtr, // with activation
bool // with fq
> UnfuseReshapeAndTransposeParams;
class UnfuseReshapeAndTransposeTestSuiteFixture: public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<UnfuseReshapeAndTransposeParams> {
public:
void SetUp() override;
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
void UnfuseReshapeAndTransposeTestSuiteFixture::SetUp() {
std::tuple<ngraph::Shape, ngraph::Shape, bool, bool> conv_data;
bool with_bias;
bool with_pool;
bool with_fq;
ActivationFactoryPtr af;
std::tie(conv_data, with_bias, with_pool, af, with_fq) = this->GetParam();
ngraph::Shape conv_input_shape;
ngraph::Shape conv_filter_shape;
bool replace_before;
bool replace_after;
std::tie(conv_input_shape, conv_filter_shape, replace_before, replace_after) = conv_data;
function = createFunction(conv_input_shape, conv_filter_shape, with_bias, with_pool, af, with_fq, true, true);
reference_function = createFunction(conv_input_shape, conv_filter_shape, with_bias, with_pool, af, with_fq, !replace_before, !replace_after);
}
void execute_test(std::shared_ptr<ngraph::Function> function,
std::shared_ptr<ngraph::Function> reference_function) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<GNAPluginNS::Unfuse2dto4dReshapeAndTranspose>();
manager.register_pass<GNAPluginNS::Unfuse4dto2dReshapeAndTranspose>();
manager.run_passes(function);
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid) << result.message;
}
TEST_P(UnfuseReshapeAndTransposeTestSuiteFixture, CompareFunctions) {
execute_test(function, reference_function);
}
const std::vector<ActivationFactoryPtr> activationFactories = {
nullptr,
createActivationFactory<ngraph::opset8::Relu>(),
createActivationFactory<ngraph::opset8::Sigmoid>(),
createActivationFactory<ngraph::opset8::Tanh>(),
createActivationFactory<ngraph::opset8::Abs>(),
createActivationFactory<ngraph::opset8::Log>(),
createActivationFactory<ngraph::opset8::Exp>(),
createActivationFactory<ngraph::opset8::Sign>(),
createActivationFactory<ngraph::opset8::Clamp>(0.1, 0.2)
};
INSTANTIATE_TEST_SUITE_P(UnfuseReshapeAndTransposeTestSuite, UnfuseReshapeAndTransposeTestSuiteFixture,
::testing::Combine(
::testing::ValuesIn(
std::vector<std::tuple<ngraph::Shape, ngraph::Shape, bool, bool>>{
{ngraph::Shape{1, 1, 1, 168}, ngraph::Shape{12, 1, 1, 8}, true, false},
{ngraph::Shape{1, 1, 1, 640}, ngraph::Shape{256, 1, 1, 512}, true, false},
{ngraph::Shape{1, 1, 1, 1024}, ngraph::Shape{256, 1, 1, 512}, true, false},
{ngraph::Shape{1, 1, 33, 32}, ngraph::Shape{128, 1, 33, 9}, true, false},
{ngraph::Shape{1, 1, 11, 13}, ngraph::Shape{128, 1, 11, 9}, true, false},
{ngraph::Shape{1, 1, 33, 23}, ngraph::Shape{128, 1, 11, 5}, true, false},
{ngraph::Shape{1, 1, 33, 32}, ngraph::Shape{1, 1, 33, 9}, true, true},
{ngraph::Shape{1, 1, 1, 1024}, ngraph::Shape{256, 1, 1, 1024}, true, true},
{ngraph::Shape{1, 1, 33, 32}, ngraph::Shape{1, 1, 33, 9}, true, true}}),
::testing::ValuesIn(std::vector<bool>{true, false}), // with bias
::testing::ValuesIn(std::vector<bool>{true, false}), // with max pool
::testing::ValuesIn(activationFactories), // with activation
::testing::ValuesIn(std::vector<bool>{true, false}))); // with fq
} // namespace
} // namespace testing


@@ -266,13 +266,21 @@ void GNAGraphCompiler::ConvolutionPrimitive(InferenceEngine::CNNLayerPtr layer)
         std::swap(convolution._dilation_x, convolution._dilation_y);
     }
+    auto in_kernel_w = convolution._kernel_x;
+    auto in_kernel_h = convolution._kernel_y;
+    bool transpose_h_w = false;
     // Map 2d convolution to 1d if it's possible
-    if (GNAConvolutionLayer::isMappableFrom2DTo1D(in_height, in_width, convolution._kernel_x, convolution._stride_x)) {
+    if (GNAConvolutionLayer::isMappableFrom2DTo1D(in_height, in_width, in_channels,
+                                                  convolution._kernel_y, convolution._kernel_x,
+                                                  convolution._stride_y, convolution._stride_x)) {
+        transpose_h_w = (in_height == convolution._kernel_y);
         in_width *= in_height;
         in_height = 1;
         out_width *= out_height;
         out_height = 1;
-        convolution._stride_x *= (convolution._stride_y * convolution._kernel_x);
+        convolution._stride_x *= transpose_h_w ? (convolution._stride_y * convolution._kernel_y) :
+                                                 (convolution._stride_y * convolution._kernel_x);
         convolution._kernel_x *= convolution._kernel_y;
         convolution._kernel_y = 1;
     }
@@ -304,19 +312,20 @@ void GNAGraphCompiler::ConvolutionPrimitive(InferenceEngine::CNNLayerPtr layer)
         in_height != 1) {
         // TensorFlow default layout is NHWC
         // OpenVino Default layout is NCHW
-        // GNA Convolution input is NHCW
+        // GNA Convolution input is NHCW (old) or NHWC (new)
         // When layer layout is in NHWC it means that is was created by PassManager
         return finalizeConvolution2DPrimitive(layer, in_batch, in_channels, in_height, in_width,
                                               out_batch, out_channels, out_height, out_width);
         THROW_GNA_LAYER_EXCEPTION(layer) << "Convolution 2D is not supported on GNA 1.0 library";
     }
     finalizeConvolution1DPrimitive(layer, in_batch, in_channels, in_width,
-                                   out_batch, out_channels, out_width);
+                                   out_batch, out_channels, out_width, in_kernel_w, in_kernel_h, transpose_h_w);
 }
 void GNAGraphCompiler::finalizeConvolution1DPrimitive(InferenceEngine::CNNLayerPtr layer,
     uint32_t in_batch, uint32_t in_channels, uint32_t in_width,
-    uint32_t out_batch, uint32_t out_channels, uint32_t out_width) {
+    uint32_t out_batch, uint32_t out_channels, uint32_t out_width,
+    uint32_t in_kernel_w, uint32_t in_kernel_h, bool transpose_h_w) {
     auto& convolution = dynamic_cast<ConvolutionLayer&>(*layer.get());
     printConvolutionLayer(convolution);
@@ -429,7 +438,10 @@ void GNAGraphCompiler::finalizeConvolution1DPrimitive(InferenceEngine::CNNLayerP
         ptr_weights,
         ptr_biases);
-    if (inputs->getLayout() == Layout::NHWC) {
+    // Keep both variants of kaldi models working:
+    // Old one has layout which is different from NHWC
+    // New one has layout NHWC, but it is mapped from 2d by H
+    if (inputs->getLayout() == Layout::NHWC && !transpose_h_w) {
         currentComponent.orientation_in = kDnnInterleavedOrientation;
         currentComponent.orientation_out = kDnnInterleavedOrientation;
     }
@@ -447,7 +459,7 @@ void GNAGraphCompiler::finalizeConvolution1DPrimitive(InferenceEngine::CNNLayerP
     // TODO: convolution might be not the first layer in sorted order but connected via split for example - dont know how kaldi will handle that
     if (!dnn->do_rotate_input) {
-        if (inputs->getLayout() != Layout::NHWC && LayerInfo(connectedInputLayer).isInput()) {
+        if ((inputs->getLayout() != Layout::NHWC || transpose_h_w) && LayerInfo(connectedInputLayer).isInput()) {
             // Kaldi features are opposite orientation
             dnn->do_rotate_input = true;
             dnn->num_rotate_rows = effectiveStride;
@@ -459,12 +471,16 @@ void GNAGraphCompiler::finalizeConvolution1DPrimitive(InferenceEngine::CNNLayerP
     connectOutput(layer, ptr_outputs, num_data_bytes_out);
+    // Transpose H with W or C with HW
+    auto A = transpose_h_w ? in_kernel_h : in_channels;
+    auto B = transpose_h_w ? in_kernel_w : convolution._kernel[X_AXIS];
     std::vector<uint8_t> transposedWeights;
     for (uint32_t k = 0; k < convolution._out_depth; k++) {
         uint8_t * ptr_filt_current
             = convolution._weights->cbuffer().as<uint8_t*>() +
-            k * in_channels * convolution._kernel[X_AXIS] * convolution.precision.size();
-        auto transposedPart = transposeMatrix(ptr_filt_current, convolution.precision.size(), in_channels, convolution._kernel[X_AXIS]);
+            k * A * B * convolution.precision.size();
+        auto transposedPart = transposeMatrix(ptr_filt_current, convolution.precision.size(), A, B);
         transposedWeights.insert(transposedWeights.end(), transposedPart.begin(), transposedPart.end());
     }
     if (transposedWeights.size() != convolution._weights->byteSize()) {


@@ -128,8 +128,8 @@ public:
     void finalizeConvolution1DPrimitive(InferenceEngine::CNNLayerPtr,
         uint32_t in_batch, uint32_t in_channels, uint32_t in_width,
-        uint32_t out_batch, uint32_t out_channels, uint32_t out_width);
+        uint32_t out_batch, uint32_t out_channels, uint32_t out_width,
+        uint32_t in_kernel_x, uint32_t in_kernel_y, bool transpose);
     void finalizeConvolution2DPrimitive(InferenceEngine::CNNLayerPtr,
         uint32_t in_batch, uint32_t in_channels, uint32_t in_height, uint32_t in_width,
         uint32_t out_batch, uint32_t out_channels, uint32_t out_height, uint32_t out_width);


@@ -89,7 +89,8 @@ inline std::pair<InferenceEngine::CNNLayerPtr, InferenceEngine::CNNLayerPtr> Fin
     auto next = getInputTo(layer->outData.front()).begin()->second;
     // Permute is inserted before Reshape by MO in NHWC models, so we need to find either permute, or reshape, or output
-    while (!LayerInfo(next).isPermute() && !LayerInfo(next).isOutput() && next->outData.size() == 1) {
+    while (!LayerInfo(next).isPermute() && !LayerInfo(next).isPermuteViaReshape() &&
+           !LayerInfo(next).isOutput() && next->outData.size() == 1) {
         if (LayerInfo(next).isNonFunctional() && !IsReshapeFrom4dTo3d(next) && !IsReshapeFrom3dTo4d(next)) {
             break;
         }
@@ -111,14 +112,27 @@ inline std::pair<InferenceEngine::CNNLayerPtr, InferenceEngine::CNNLayerPtr> Fin
         if (next->outData.size() != 1) {
             return std::make_pair(nullptr, nullptr);
         }
-        // Check if reshape is expected for this pattern:
-        // the next layer has the both, height and width dimensions > 1
-        auto in_dim_size = next->insData[0].lock()->getDims().size();
-        IE_ASSERT(in_dim_size == 3 || in_dim_size == 4);
-        size_t height = in_dim_size == 3 ? 1 : GetDataDimSize(next->insData[0].lock(), InferenceEngine::DataDimName::H);
-        size_t width = GetDataDimSize(next->insData[0].lock(), InferenceEngine::DataDimName::W);
-        if (next->outData[0]->getDims().size() < 3 || height != 1 || width != 1) {
-            return std::make_pair(nullptr, nullptr);
+        auto input_dims = next->insData[0].lock()->getDims();
+        auto in_dims_size = input_dims.size();
+        auto output_dims = next->outData[0]->getDims();
+        auto out_dims_size = output_dims.size();
+        if (in_dims_size == 4 && out_dims_size == 4) {
+            if (!LayerInfo(next).isPermuteViaReshape() ||
+                (input_dims[0] != output_dims[0]) || // N
+                (input_dims[1] != output_dims[3]) || // C
+                (input_dims[2] != output_dims[1]) || // H
+                (input_dims[3] != output_dims[2])) { // W
+                return std::make_pair(nullptr, nullptr);
+            }
+        } else {
+            // Check if reshape is expected for this pattern:
+            // the next layer has the both, height and width dimensions > 1
+            IE_ASSERT(in_dims_size == 3 || in_dims_size == 4);
+            size_t height = in_dims_size == 3 ? 1 : GetDataDimSize(next->insData[0].lock(), InferenceEngine::DataDimName::H);
+            size_t width = GetDataDimSize(next->insData[0].lock(), InferenceEngine::DataDimName::W);
+            if (out_dims_size < 3 || height != 1 || width != 1) {
+                return std::make_pair(nullptr, nullptr);
+            }
         }
     } else {
         return std::make_pair(nullptr, nullptr);
@@ -127,7 +141,8 @@ inline std::pair<InferenceEngine::CNNLayerPtr, InferenceEngine::CNNLayerPtr> Fin
     // Permute is inserted after Reshape by MO in NHWC models, so we need to find either permute, or reshape, or input
     auto parent = InferenceEngine::CNNNetPrevLayer(layer);
     auto prev = parent;
-    while (!LayerInfo(prev).isPermute() && !LayerInfo(prev).isInput() && InferenceEngine::CNNNetHasPrevLayer(prev.get())) {
+    while (!LayerInfo(prev).isPermute() && !LayerInfo(prev).isPermuteViaReshape() &&
+           !LayerInfo(prev).isInput() && InferenceEngine::CNNNetHasPrevLayer(prev.get())) {
         if (LayerInfo(prev).isNonFunctional() && !IsReshapeFrom4dTo3d(prev) && !IsReshapeFrom3dTo4d(prev)) {
             break;
         }
@@ -142,19 +157,35 @@ inline std::pair<InferenceEngine::CNNLayerPtr, InferenceEngine::CNNLayerPtr> Fin
             order != std::vector<int32_t>{0, 2, 1} /* NWC to NCW */) {
             return std::make_pair(nullptr, nullptr);
         }
     } else if (LayerInfo(prev).isReshape()) {
-        if (parent->outData.size() != 1 || InferenceEngine::getInputTo(parent->outData[0]).size() != 1) {
-            return std::make_pair(nullptr, nullptr);
-        }
-        // Check if reshape is expected for this pattern:
-        // the previous layer has number of channels > 1 and one of height/width dimensions is also > 1
-        size_t out_dims_size = parent->outData[0]->getDims().size();
-        IE_ASSERT(out_dims_size == 3 || out_dims_size == 4);
-        size_t channels = GetDataDimSize(parent->outData[0], out_dims_size - 1);
-        size_t height = out_dims_size == 3 ? 1 : GetDataDimSize(parent->outData[0], InferenceEngine::DataDimName::H);
-        size_t width = GetDataDimSize(parent->outData[0], InferenceEngine::DataDimName::W);
-        if (parent->insData[0].lock()->getDims().size() < 3 || channels != 1 && (height != 1 || width != 1)) {
-            return std::make_pair(nullptr, nullptr);
+        auto input_dims = prev->insData[0].lock()->getDims();
+        auto in_dims_size = input_dims.size();
+        auto output_dims = prev->outData[0]->getDims();
+        auto out_dims_size = output_dims.size();
+
+        if (in_dims_size == 4 && out_dims_size == 4) {
+            if (!LayerInfo(prev).isPermuteViaReshape() ||
+                (input_dims[0] != output_dims[0]) || // N
+                (input_dims[1] != output_dims[2]) || // H
+                (input_dims[2] != output_dims[3]) || // W
+                (input_dims[3] != output_dims[1])) { // C
+                return std::make_pair(nullptr, nullptr);
+            }
+        } else {
+            if (parent->outData.size() != 1 || InferenceEngine::getInputTo(parent->outData[0]).size() != 1) {
+                return std::make_pair(nullptr, nullptr);
+            }
+            // Check if reshape is expected for this pattern:
+            // the previous layer has number of channels > 1 and one of height/width dimensions is also > 1
+            in_dims_size = parent->insData[0].lock()->getDims().size();
+            out_dims_size = parent->outData[0]->getDims().size();
+            IE_ASSERT(out_dims_size == 3 || out_dims_size == 4);
+            size_t channels = GetDataDimSize(parent->outData[0], out_dims_size - 1);
+            size_t height = out_dims_size == 3 ? 1 : GetDataDimSize(parent->outData[0], InferenceEngine::DataDimName::H);
+            size_t width = GetDataDimSize(parent->outData[0], InferenceEngine::DataDimName::W);
+            if (in_dims_size < 3 || channels != 1 && (height != 1 || width != 1)) {
+                return std::make_pair(nullptr, nullptr);
+            }
         }
     } else {
         return std::make_pair(nullptr, nullptr);


@@ -79,6 +79,7 @@
 #include "transformations/decompose_mvn.hpp"
 #include "transformations/substitute_softsign.hpp"
 #include "transformations/convert_precision.hpp"
+#include "transformations/unfuse_reshape_and_transpose.hpp"
 #include <ngraph/opsets/opset7.hpp>
@@ -698,6 +699,9 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
         manager.register_pass<SwapInputMatMul>();
         manager.register_pass<HandleTransposesAroundMatMul>();
         manager.register_pass<InsertTransposeAfterConvOrPool>();
+        manager.register_pass<Unfuse2dto4dReshapeAndTranspose>();
+        manager.register_pass<Unfuse4dto2dReshapeAndTranspose>();
+        manager.register_pass<RemoveExtraReshapes>();
         manager.register_pass<ReorderActivationAndPooling>();
         manager.register_pass<RemoveSingleInputConcat>();
         manager.register_pass<SubstituteSoftsign>();
@@ -792,7 +796,10 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
     passes->registerPass<SubstitutePReluPass>();
-    passes->registerPass<ReorderMaxPoolPass>();
+    if (!isNgraphPassesUsed) {
+        passes->registerPass<ReorderMaxPoolPass>();
+    }
     passes->registerPass<EltwiseSplitOverChannelsPass>();
     passes->registerPass<InsertSplitAligningFilterPass>();


@@ -16,8 +16,13 @@
 namespace GNAPluginNS {
 namespace GNAConvolutionLayer {
-bool isMappableFrom2DTo1D(const uint32_t inHeight, const uint32_t inWidth, const uint32_t kernelWidth, const uint32_t strideWidth) {
-    return inHeight > 1 && inWidth > 1 && inWidth == kernelWidth && strideWidth == 1;
+bool isMappableFrom2DTo1D(const uint32_t inHeight, const uint32_t inWidth, const uint32_t in_channels,
+                          const uint32_t kernelHeight, const uint32_t kernelWidth,
+                          const uint32_t strideHeight, const uint32_t strideWidth) {
+    return ((inHeight > 1 && inWidth > 1) &&
+            ((inWidth == kernelWidth && strideWidth == 1) ||
+             (inHeight == kernelHeight && strideHeight == 1 && in_channels == 1)));
 }
 // 3D input or 2D kernel
@@ -39,7 +44,7 @@ double getWeightsReducer(InferenceEngine::ConvolutionLayer& conv) {
     const auto inHeight = GetDataDimSize(conv.insData.front().lock(), InferenceEngine::DataDimName::H);
     const auto inWidth = GetDataDimSize(conv.insData.front().lock(), InferenceEngine::DataDimName::W);
     if (isConv2D(inHeight, inWidth, inDepth, conv._kernel_y, conv._kernel_x) &&
-        !isMappableFrom2DTo1D(inHeight, inWidth, conv._kernel_x, conv._stride_x)) {
+        !isMappableFrom2DTo1D(inHeight, inWidth, inDepth, conv._kernel_y, conv._kernel_x, conv._stride_y, conv._stride_x)) {
         const auto kernelSize = conv._kernel_x * conv._kernel_y;
         auto r = std::lower_bound(reducers.begin(), reducers.end(), kernelSize,
             [](const KRT& l, const KRT::first_type& r) {return l.first > r; });
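The extended predicate now also accepts a convolution whose kernel spans the full input height, provided the vertical stride is 1 and there is a single input channel. A minimal standalone sketch of the same check with illustrative shapes (not the plugin's header; the sample dimensions are taken from the new test instantiation):

#include <cstdint>
#include <iostream>

// Standalone restatement of the predicate shown in the hunk above.
static bool isMappableFrom2DTo1D(uint32_t inHeight, uint32_t inWidth, uint32_t inChannels,
                                 uint32_t kernelHeight, uint32_t kernelWidth,
                                 uint32_t strideHeight, uint32_t strideWidth) {
    return (inHeight > 1 && inWidth > 1) &&
           ((inWidth == kernelWidth && strideWidth == 1) ||                        // old case: kernel covers the full width
            (inHeight == kernelHeight && strideHeight == 1 && inChannels == 1));   // new case: kernel covers the full height
}

int main() {
    // 1x1x33x32 input with a 33x9 kernel: mappable via the new full-height case.
    std::cout << isMappableFrom2DTo1D(33, 32, 1, 33, 9, 1, 1) << "\n";  // 1
    // Same shapes but 3 input channels: not mappable.
    std::cout << isMappableFrom2DTo1D(33, 32, 3, 33, 9, 1, 1) << "\n";  // 0
}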


@@ -10,7 +10,9 @@
 namespace GNAPluginNS {
 namespace GNAConvolutionLayer {
-bool isMappableFrom2DTo1D(const uint32_t inHeight, const uint32_t inWidth, const uint32_t kernelWidth, const uint32_t strideWidth);
+bool isMappableFrom2DTo1D(const uint32_t inHeight, const uint32_t inWidth, const uint32_t inChannels,
+                          const uint32_t kernelHeight, const uint32_t kernelWidth,
+                          const uint32_t strideHeight, const uint32_t strideWidth);
 // 3D input or 2D kernel
 bool isConv2D(const uint32_t inHeight, const uint32_t inWidth, const uint32_t inDepth,


@@ -271,6 +271,24 @@ class LayerInfo {
     bool isPermute() const noexcept {
         return isOfType("permute");
     }
+    bool isPermuteViaReshape() const {
+        if (!isOfType("reshape")) return false;
+        auto input_dims = layer->insData[0].lock()->getDims();
+        auto output_dims = layer->outData[0]->getDims();
+        if (input_dims.size() != output_dims.size()) {
+            return false;
+        }
+        input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 1), input_dims.end());
+        output_dims.erase(std::remove(output_dims.begin(), output_dims.end(), 1), output_dims.end());
+        if (input_dims != output_dims) {
+            return false;
+        }
+        return true;
+    }
     // @brief this not only mathematically trivial, has some WA for kaldi case
     bool isTrivialPermute() const {
         if (!isPermute()) return false;
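isPermuteViaReshape() accepts a Reshape of equal rank whose non-unit dimensions are unchanged, i.e. a reshape that only moves data across trivial (size-1) axes and therefore behaves like a permute. A minimal sketch of the same rule on plain shape vectors (illustrative shapes, not the plugin's API):

#include <algorithm>
#include <cassert>
#include <vector>

// Same rule as isPermuteViaReshape(): drop all size-1 dims and compare the rest.
static bool permuteLikeReshape(std::vector<size_t> in, std::vector<size_t> out) {
    if (in.size() != out.size()) return false;
    in.erase(std::remove(in.begin(), in.end(), 1), in.end());
    out.erase(std::remove(out.begin(), out.end(), 1), out.end());
    return in == out;
}

int main() {
    assert(permuteLikeReshape({1, 64, 1, 1}, {1, 1, 1, 64}));   // NCHW -> NHWC over trivial axes only
    assert(!permuteLikeReshape({1, 2, 3, 4}, {1, 4, 3, 2}));    // real data reordering, rejected
}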


@@ -633,11 +633,11 @@ void RemovePermutationsNHWCToNCHWPass::run() {
         if (prev == nullptr || next == nullptr) continue;
-        if (LayerInfo(prev).isPermute()) {
+        if (LayerInfo(prev).isPermute() || LayerInfo(prev).isPermuteViaReshape()) {
             permutations_to_remove.insert(prev);
         }
-        if (LayerInfo(next).isPermute()) {
+        if (LayerInfo(next).isPermute() || LayerInfo(prev).isPermuteViaReshape()) {
             permutations_to_remove.insert(next);
         }
@@ -699,7 +699,8 @@ void RemovePermutationsNHWCToNCHWPass::run() {
         };
         propogateNHWCOrderRecursive(current_layer);
-        if (LayerInfo(pattern_start).isPermute() && !getInputTo(pattern_start->outData.front()).empty()) {
+        if ((LayerInfo(pattern_start).isPermute() || LayerInfo(pattern_start).isPermuteViaReshape()) &&
+            !getInputTo(pattern_start->outData.front()).empty()) {
             auto layer_before_permute = CNNNetPrevLayer(pattern_start);
             DataPtr output = nullptr;
             for (auto before_output : layer_before_permute->outData) {
@@ -2017,11 +2018,11 @@ void MoveFakeQuantizeLayerIntoQuantParamsPass :: run() {
     };
     auto allowFQFuse = [](CNNLayerPtr layer) -> bool {
-        auto doNotSkup = [](CNNLayerPtr layer) {
+        auto doNotSkip = [](CNNLayerPtr layer) {
             return false;
         };
-        if (CNNNetGetAllNextLayersSkipCertain(layer, -1, doNotSkup).empty()) {
+        if (CNNNetGetAllNextLayersSkipCertain(layer, -1, doNotSkip).empty()) {
             return false;
         }
@@ -2142,7 +2143,7 @@ void MoveFakeQuantizeLayerIntoQuantParamsPass :: run() {
     // Before FQ layer is removed, the previous functional layer has to be updated with its quantization data
     auto prevFuncLayer = CNNNetPrevLayerSkipCertain(*fqLayer, 0, [](CNNLayerPtr layer) {
-        return LayerInfo(layer).isNonFunctional();
+        return LayerInfo(layer).isNonFunctional() || LayerInfo(layer).isPooling();
     });
     auto quantParamsPrevLayer = InferenceEngine::getInjectedData<QuantizedLayerParams>(prevFuncLayer);
     quantParamsPrevLayer->_dst_quant.SetLevels(fqLevels);


@@ -128,7 +128,9 @@ static bool ShouldDecompose(GraphData& graph_data, const ConvData& conv_data) {
     // GNA supported features or handled otherwise - there is no need to decompose such convolution
     if (graph_data.conv_count == 1 && (((conv_data.input_height == 1 || conv_data.input_width == 1) &&
         conv_data.filter_dilation_width == 1 && conv_data.filter_dilation_height == 1) ||
-        GNAConvolutionLayer::isMappableFrom2DTo1D(conv_data.input_height, conv_data.input_width, conv_data.filter_width, conv_data.filter_stride_width)))
+        GNAConvolutionLayer::isMappableFrom2DTo1D(conv_data.input_height, conv_data.input_width, conv_data.input_channel_count,
+            conv_data.filter_height, conv_data.filter_width,
+            conv_data.filter_stride_height, conv_data.filter_stride_width)))
         return false;
     return true;


@@ -8,6 +8,7 @@
 #include <ngraph/opsets/opset7.hpp>
 #include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/pattern/op/or.hpp>
 using namespace GNAPluginNS;
@@ -15,16 +16,15 @@ NGRAPH_RTTI_DEFINITION(RemoveExtraReshapes, "RemoveExtraReshapes", 0);
 RemoveExtraReshapes::RemoveExtraReshapes() {
     MATCHER_SCOPE(RemoveExtraReshapes);
-    const auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>();
+    const auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>(
+        [](const ngraph::Output<ngraph::Node>& value) {
+            return (value.get_node_shared_ptr()->get_input_shape(0) == value.get_node_shared_ptr()->get_output_shape(0));
+        });
     const auto pooling = ngraph::pattern::wrap_type<ngraph::opset7::MaxPool>({reshape});
     ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
         const auto& pattern_map = m.get_pattern_value_map();
         const auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
-        if (reshape_node->get_input_shape(0) != reshape_node->get_output_shape(0)) {
-            return false;
-        }
         ngraph::replace_output_update_name(reshape_node->output(0), reshape_node->input_value(0));
         return true;
     };


@@ -33,9 +33,10 @@ static bool shouldSplitCnn(const ngraph::Output<ngraph::Node>& node) {
     uint32_t height = input.at(2);
     auto kH = filters.at(2);
     auto kW = filters.at(3);
+    auto sH = convolution->get_strides().at(0);
     auto sW = convolution->get_strides().at(1);
     if (GNAConvolutionLayer::isConv2D(height, width, in_channels, kH, kW) &&
-        !GNAConvolutionLayer::isMappableFrom2DTo1D(height, width, kW, sW)) {
+        !GNAConvolutionLayer::isMappableFrom2DTo1D(height, width, in_channels, kH, kW, sH, sW)) {
         return false;
     }
 }


@@ -0,0 +1,189 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/cc/ngraph/itt.hpp>
#include "transformations/unfuse_reshape_and_transpose.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations/utils/transformation_helper.hpp"
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
using namespace GNAPluginNS;
NGRAPH_RTTI_DEFINITION(Unfuse2dto4dReshapeAndTranspose, "Unfuse2dto4dReshapeAndTranspose", 0);
Unfuse2dto4dReshapeAndTranspose::Unfuse2dto4dReshapeAndTranspose() {
MATCHER_SCOPE(Unfuse2dto4dReshapeAndTranspose);
auto is_required_reshape = [](const ngraph::Output<ngraph::Node>& value) {
auto input_shape = value.get_node_shared_ptr()->get_input_shape(0);
auto output_shape = value.get_node_shared_ptr()->get_output_shape(0);
return ((input_shape.size() == 2) && (output_shape.size() == 4) &&
((output_shape.at(1) == 1) || (output_shape.at(2)*output_shape.at(3) == 1)));
};
const auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(is_required_reshape);
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({reshape,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
const auto conv = ngraph::pattern::wrap_type<ngraph::opset8::Convolution>({std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, fq}),
ngraph::pattern::any_input()});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
auto consumers = reshape_node->output(0).get_target_inputs();
auto N = reshape_node->get_output_shape(0)[0];
auto C = reshape_node->get_output_shape(0)[1];
auto H = reshape_node->get_output_shape(0)[2];
auto W = reshape_node->get_output_shape(0)[3];
// Create reshape NxW => NxHxWxC (C or HxW is equal to 1)
auto data = reshape_node->input_value(0);
auto reshape_nhwc_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{N, H, W, C});
auto reshape_nhwc = register_new_node<ngraph::opset8::Reshape>(data, reshape_nhwc_const, false);
reshape_nhwc->set_friendly_name(reshape_node->get_friendly_name() + "/Reshape");
// Create transpose NxHxWxC => NxCxHxW
auto transpose_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 3, 1, 2});
auto transpose = register_new_node<ngraph::opset8::Transpose>(reshape_nhwc, transpose_const);
transpose->set_friendly_name(reshape_node->get_friendly_name());
ngraph::copy_runtime_info(reshape, {reshape_nhwc, transpose});
for (auto consumer : consumers) {
consumer.replace_source_output(transpose);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, matcher_name);
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(Unfuse4dto2dReshapeAndTranspose, "Unfuse4dto2dReshapeAndTranspose", 0);
Unfuse4dto2dReshapeAndTranspose::Unfuse4dto2dReshapeAndTranspose() {
MATCHER_SCOPE(Unfuse4dto2dReshapeAndTranspose);
auto is_required_reshape = [](const ngraph::Output<ngraph::Node>& value) {
auto input_shape = value.get_node_shared_ptr()->get_input_shape(0);
auto output_shape = value.get_node_shared_ptr()->get_output_shape(0);
return ((input_shape.size() == 4) && (output_shape.size() == 2) &&
((input_shape.at(1) == 1) || (input_shape.at(2)*input_shape.at(3) == 1)));
};
// Convolution
auto conv = ngraph::pattern::wrap_type<ngraph::opset8::Convolution>({ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto fq_conv = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({conv,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
// Bias
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Add>({conv, ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto fq_bias = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({bias,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
// Max Pooling
auto max_pool_conv = ngraph::pattern::wrap_type<ngraph::opset7::MaxPool>({conv},
consumers_and_rank(1, 4));
auto max_pool_fq_conv = ngraph::pattern::wrap_type<ngraph::opset7::MaxPool>({fq_conv},
consumers_and_rank(1, 4));
auto max_pool_bias = ngraph::pattern::wrap_type<ngraph::opset7::MaxPool>({bias},
consumers_and_rank(1, 4));
auto max_pool_fq_bias = ngraph::pattern::wrap_type<ngraph::opset7::MaxPool>({fq_bias},
consumers_and_rank(1, 4));
// Activation
auto fq_fq_conv = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_conv,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto fq_fq_bias = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_bias,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto act_conv = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({conv},
consumers_and_rank(1, 4));
auto act_bias = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({bias},
consumers_and_rank(1, 4));
auto act_max_pool_conv = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({max_pool_conv},
consumers_and_rank(1, 4));
auto act_max_pool_bias = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({max_pool_bias},
consumers_and_rank(1, 4));
auto act_fq_fq_conv = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({fq_fq_conv},
consumers_and_rank(1, 4));
auto act_fq_fq_bias = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({fq_fq_bias},
consumers_and_rank(1, 4));
auto fq_max_pool_fq_conv = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({max_pool_fq_conv,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto act_fq_max_pool_fq_conv = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({fq_max_pool_fq_conv},
consumers_and_rank(1, 4));
auto fq_max_pool_fq_bias = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({max_pool_fq_bias,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto act_fq_max_pool_fq_bias = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({fq_max_pool_fq_bias},
consumers_and_rank(1, 4));
auto fq_act_fq_fq_conv = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({act_fq_fq_conv,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto fq_act_fq_fq_bias = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({act_fq_fq_bias,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto fq_act_fq_max_pool_fq_conv = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({act_fq_max_pool_fq_conv,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto fq_act_fq_max_pool_fq_bias = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({act_fq_max_pool_fq_bias,
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
consumers_and_rank(1, 4));
auto root_reshape =
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{conv, bias, max_pool_conv, max_pool_fq_conv, max_pool_bias, max_pool_fq_bias,
fq_conv, fq_bias, act_conv, act_bias, act_max_pool_conv, act_max_pool_bias,
fq_act_fq_fq_conv, fq_act_fq_fq_bias, fq_act_fq_max_pool_fq_conv, fq_act_fq_max_pool_fq_bias});
const auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>({root_reshape, ngraph::pattern::any_input()}, is_required_reshape);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
auto consumers = reshape_node->output(0).get_target_inputs();
auto N = reshape_node->get_input_shape(0)[0];
auto W = reshape_node->get_input_shape(0)[1]*reshape_node->get_input_shape(0)[2]*reshape_node->get_input_shape(0)[3];
// Create transpose NxCxHxW => NxHxWxC
auto data = reshape_node->input_value(0);
auto transpose_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 2, 3, 1});
auto transpose = register_new_node<ngraph::opset8::Transpose>(data, transpose_const);
transpose->set_friendly_name(reshape_node->get_friendly_name() + "/Transpose");
// Create reshape NxHxWxC => NxW (C or HxW is equal to 1)
auto reshape_nw_const = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{N, W});
auto reshape_nw = register_new_node<ngraph::opset8::Reshape>(transpose, reshape_nw_const, false);
reshape_nw->set_friendly_name(reshape_node->get_friendly_name());
ngraph::copy_runtime_info(reshape_node, {transpose, reshape_nw});
for (auto consumer : consumers) {
consumer.replace_source_output(reshape_nw);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, matcher_name);
this->register_matcher(m, callback);
}


@@ -0,0 +1,72 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
namespace GNAPluginNS {
/**
* @brief Replace 2d->4d reshape to pair of 2 reshapes (before Convolution)
* Before:
* [N, HW]
* |
* Reshape
* |
* [N, C, H, W]
* |
* Convolution
*
* After (TransposeSinking friendly):
* [N, HW]
* |
* Reshape
* |
* [N, H, W, C]
* |
* Reshape
* |
* [N, C, H, W]
* |
* Convolution
*/
class Unfuse2dto4dReshapeAndTranspose : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
Unfuse2dto4dReshapeAndTranspose();
};
/**
* @brief Replace 2d->4d reshape to pair of 2 reshapes (after Convolution)
* Before:
* Convolution (optionally + bias/pooling/activation)
* |
* [N, C, H, W]
* |
* Reshape
* |
* [N, HW]
*
* After (TransposeSinking friendly):
* Convolution
* |
* [N, C, H, W]
* |
* Reshape
* |
* [N, H, W, C]
* |
* Reshape
* |
* [N, HW]
*
*/
class Unfuse4dto2dReshapeAndTranspose : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
Unfuse4dto2dReshapeAndTranspose();
};
} // namespace GNAPluginNS
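Both classes are MatcherPasses, so they are applied through a regular ngraph::pass::Manager, the same way the plugin registers them in GNAPlugin::LoadNetwork and the unit test drives them. A minimal usage sketch (the function `model` is a placeholder for any ngraph::Function containing the fused Reshape around a Convolution):

#include <ngraph/function.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include "transformations/unfuse_reshape_and_transpose.hpp"

// Sketch: apply both unfuse passes to a model, mirroring the plugin/test wiring.
void unfuse_reshapes(const std::shared_ptr<ngraph::Function>& model) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<GNAPluginNS::Unfuse2dto4dReshapeAndTranspose>();  // before Convolution
    manager.register_pass<GNAPluginNS::Unfuse4dto2dReshapeAndTranspose>();  // after Convolution
    manager.run_passes(model);
}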


@@ -24,16 +24,20 @@ typedef std::tuple<
     std::string,                        // Target Device
     std::map<std::string, std::string>, // Configuration
     std::vector<size_t>,                // Input Shape
-    std::pair<float, float>,            // Input Min and Max
-    size_t                              // Levels
+    std::pair<float, float>,            // Input Min and Max (before conv)
+    std::pair<float, float>,            // Input Min and Max (after conv)
+    size_t,                             // Levels
+    bool                                // Reshape between FQ and Pooling
 > fqMaxpoolReorderingParams;
 namespace LayerTestsDefinitions {
 class FQMaxpoolReordering : public testing::WithParamInterface<fqMaxpoolReorderingParams>,
     public LayerTestsUtils::LayerTestsCommon {
-    float inputDataMin = 0.0f;
-    float inputDataMax = 0.0f;
+    float inputDataMin1 = 0.0f;
+    float inputDataMax1 = 0.0f;
+    float inputDataMin2 = 0.0f;
+    float inputDataMax2 = 0.0f;
     float inputDataResolution = 1.0f;
 public:
@@ -42,9 +46,11 @@ public:
     std::string targetDevice;
     std::map<std::string, std::string> configuration;
     std::vector<size_t> inputShape;
-    std::pair<float, float> inputMinMax;
+    std::pair<float, float> inputMinMax1;
+    std::pair<float, float> inputMinMax2;
     size_t levels = 0;
-    std::tie(netPrecision, targetDevice, configuration, inputShape, inputMinMax, levels) = obj.param;
+    bool reshape = false;
+    std::tie(netPrecision, targetDevice, configuration, inputShape, inputMinMax1, inputMinMax2, levels, reshape) = obj.param;
     std::ostringstream result;
     result << "netPRC=" << netPrecision.name() << "_";
@@ -53,14 +59,16 @@ public:
         result << "_configItem=" << configItem.first << "_" << configItem.second;
     }
     result << "_inputShape=" << CommonTestUtils::vec2str(inputShape);
-    result << "_inputMinMax=(" << inputMinMax.first << ".." << inputMinMax.second << ")";
+    result << "_inputMinMax1=(" << inputMinMax1.first << ".." << inputMinMax1.second << ")";
+    result << "_inputMinMax2=(" << inputMinMax2.first << ".." << inputMinMax2.second << ")";
     result << "_levels=" << levels;
+    result << "_reshape=" << reshape;
     return result.str();
 }
 InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override {
-    return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), inputDataMax - inputDataMin, inputDataMin, 1 / inputDataResolution);
+    return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), inputDataMax1 - inputDataMin1, inputDataMin1, 1 / inputDataResolution);
 }
 protected:
@@ -68,23 +76,28 @@ protected:
     InferenceEngine::Precision netPrecision;
     std::vector<size_t> inputShape;
-    std::pair<float, float> inputMinMax;
+    std::pair<float, float> inputMinMax1;
+    std::pair<float, float> inputMinMax2;
     size_t levels = 0;
-    std::tie(netPrecision, targetDevice, configuration, inputShape, inputMinMax, levels) = this->GetParam();
+    bool reshape = false;
+    std::tie(netPrecision, targetDevice, configuration, inputShape, inputMinMax1, inputMinMax2, levels, reshape) = this->GetParam();
     auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
-    std::tie(inputDataMin, inputDataMax) = inputMinMax;
-    auto inputLowNode = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputDataMin });
-    auto inputHighNode = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputDataMax });
+    std::tie(inputDataMin1, inputDataMax1) = inputMinMax1;
+    std::tie(inputDataMin2, inputDataMax2) = inputMinMax2;
+    auto inputLowNode1 = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputDataMin1 });
+    auto inputHighNode1 = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputDataMax1 });
+    auto inputLowNode2 = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputDataMin2 });
+    auto inputHighNode2 = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputDataMax2 });
     auto inputVector = ngraph::builder::makeParams(ngPrc, {inputShape});
     auto inputFQ = std::make_shared<ngraph::opset1::FakeQuantize>(inputVector[0],
-        inputLowNode, inputHighNode, inputLowNode, inputHighNode, levels);
+        inputLowNode1, inputHighNode1, inputLowNode1, inputHighNode1, levels);
     auto filterWeightsNode = ngraph::builder::makeConstant<float>(ngPrc, {8, inputShape[1], 1, 8}, { 1.0f });
-    auto convLowNode = ngraph::builder::makeConstant(ngraph::element::f32, std::vector<size_t>{ 1 }, std::vector<float>{inputDataMin});
-    auto convHighNode = ngraph::builder::makeConstant(ngraph::element::f32, std::vector<size_t>{ 1 }, std::vector<float>{inputDataMax});
+    auto convLowNode = ngraph::builder::makeConstant(ngraph::element::f32, std::vector<size_t>{ 1 }, std::vector<float>{inputDataMin1});
+    auto convHighNode = ngraph::builder::makeConstant(ngraph::element::f32, std::vector<size_t>{ 1 }, std::vector<float>{inputDataMax1});
     auto convWeightsFQNode = std::make_shared<ngraph::opset1::FakeQuantize>(filterWeightsNode,
         convLowNode, convHighNode, convLowNode, convHighNode, levels);
     auto convWeightsFQ = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(convWeightsFQNode);
@@ -97,9 +110,20 @@ protected:
     auto add = std::make_shared<ngraph::opset1::Add>(conv, biasesWeightsNode);
     auto convFQNode = std::make_shared<ngraph::opset1::FakeQuantize>(add,
-        inputLowNode, inputHighNode, inputLowNode, inputHighNode, levels);
-    auto maxpool = ngraph::builder::makePooling(convFQNode, {1, 2}, {0, 0}, {0, 0}, {1, 2}, ngraph::op::RoundingType::FLOOR,
+        inputLowNode2, inputHighNode2, inputLowNode2, inputHighNode2, levels);
+    std::shared_ptr<ngraph::Node> node_before_pooling = convFQNode;
+    if (reshape) {
+        const auto& shape = conv->get_output_shape(0);
+        size_t total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
+        auto reshapeConst1 = ngraph::builder::makeConstant(ngraph::element::i64, std::vector<size_t>{ 2 }, ngraph::Shape{1, total});
+        auto reshapeNode1 = std::make_shared<ngraph::opset1::Reshape>(convFQNode, reshapeConst1, false);
+        auto reshapeConst2 = ngraph::builder::makeConstant(ngraph::element::i64, std::vector<size_t>{ 4 }, shape);
+        auto reshapeNode2 = std::make_shared<ngraph::opset1::Reshape>(reshapeNode1, reshapeConst2, false);
+        node_before_pooling = reshapeNode2;
+    }
+    auto maxpool = ngraph::builder::makePooling(node_before_pooling, {1, 2}, {0, 0}, {0, 0}, {1, 2}, ngraph::op::RoundingType::FLOOR,
         ngraph::op::PadType::VALID, false, ngraph::helpers::PoolingTypes::MAX);
     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxpool)};
@@ -130,7 +154,9 @@ const std::vector<std::vector<size_t>> inputShape = {
 const std::vector<std::pair<float, float>> inputMinMax = {
     {-0.5, 0.5},
     {-2, 2},
-    {-8, 8}
+    {-8, 8},
+    {-5, 5},
+    {-17.5, 17.5},
 };
 const std::vector<size_t> levels = {
@@ -144,6 +170,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_fq_maxpool_reordering, FQMaxpoolReordering,
         ::testing::ValuesIn(configs),
         ::testing::ValuesIn(inputShape),
         ::testing::ValuesIn(inputMinMax),
-        ::testing::ValuesIn(levels)),
+        ::testing::ValuesIn(inputMinMax),
+        ::testing::ValuesIn(levels),
+        ::testing::ValuesIn(std::vector<bool>{true, false})),
     FQMaxpoolReordering::getTestCaseName);
 } // namespace LayerTestsDefinitions


@ -18,6 +18,7 @@
#include "ngraph_functions/builders.hpp" #include "ngraph_functions/builders.hpp"
#include "ngraph_functions/pass/convert_prc.hpp" #include "ngraph_functions/pass/convert_prc.hpp"
#include "transformations/common_optimizations/transpose_to_reshape.hpp"
typedef std::tuple<
InferenceEngine::Precision, // Network Precision
@ -31,7 +32,8 @@ typedef std::tuple<
std::string, // Target Device
std::map<std::string, std::string>, // Configuration
std::vector<size_t>, // Input shape
bool, // additional bool parameter
bool // transpose to reshape
> removePermutationsAddParamPassParams;
namespace LayerTestsDefinitions {
@ -106,7 +108,8 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
std::map<std::string, std::string> configuration;
std::vector<size_t> inputShape;
bool output1D;
bool transpose_to_reshape;
std::tie(netPrecision, targetDevice, configuration, inputShape, output1D, transpose_to_reshape) = obj.param;
std::ostringstream result;
result << "netPRC=" << netPrecision.name() << "_";
@ -116,6 +119,7 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
}
result << "_IS=" << CommonTestUtils::vec2str(inputShape);
result << "_1d_out=" << output1D;
result << "_transpose2reshape=" << transpose_to_reshape;
return result.str();
}
@ -133,7 +137,8 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
InferenceEngine::Precision netPrecision;
std::vector<size_t> inputShape;
bool output1D;
bool transpose_to_reshape;
std::tie(netPrecision, targetDevice, configuration, inputShape, output1D, transpose_to_reshape) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
size_t shape_size = inputShape.size();
@ -158,6 +163,11 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(reshape2) };
function = std::make_shared<ngraph::Function>(results, params, "RemovePermutationsTest");
if (transpose_to_reshape) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::TransposeToReshape>();
manager.run_passes(function);
}
}
};
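The block added above runs ngraph::pass::TransposeToReshape over the freshly built function before it is handed to the plugin. As I understand that common optimization (my reading, not something stated in this PR), it folds a Transpose that does not actually move data — for example one that only shuffles size-1 dimensions — into an equivalent Reshape, so these tests also cover the case where the layout permutations arrive as Reshapes. A self-contained sketch of invoking it, with illustrative shapes and a made-up function name:

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include "transformations/common_optimizations/transpose_to_reshape.hpp"

std::shared_ptr<ngraph::Function> foldTrivialTransposeSketch() {
    // {1, 64, 1, 1} permuted by {0, 2, 3, 1} keeps the memory layout intact,
    // so the pass is expected to replace the Transpose with a Reshape.
    auto param = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{1, 64, 1, 1});
    auto order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 2, 3, 1});
    auto transpose = std::make_shared<ngraph::opset8::Transpose>(param, order);
    auto func = std::make_shared<ngraph::Function>(
        ngraph::ResultVector{std::make_shared<ngraph::opset8::Result>(transpose)},
        ngraph::ParameterVector{param});

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::TransposeToReshape>();
    manager.run_passes(func);   // the Transpose should now have become a Reshape
    return func;
}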
@ -212,7 +222,8 @@ class RemovePermutationsWithPoolAndActTest : public testing::WithParamInterface<
std::map<std::string, std::string> configuration;
std::vector<size_t> inputShape;
bool withActivation;
bool transpose_to_reshape;
std::tie(netPrecision, targetDevice, configuration, inputShape, withActivation, transpose_to_reshape) = obj.param;
std::ostringstream result;
result << "netPRC=" << netPrecision.name() << "_";
@ -222,6 +233,7 @@ class RemovePermutationsWithPoolAndActTest : public testing::WithParamInterface<
}
result << "_IS=" << CommonTestUtils::vec2str(inputShape);
result << "_withActivation=" << withActivation;
result << "_transpose2reshape=" << transpose_to_reshape;
return result.str();
}
@ -255,7 +267,8 @@ class RemovePermutationsWithPoolAndActTest : public testing::WithParamInterface<
InferenceEngine::Precision netPrecision;
std::vector<size_t> inputShape;
bool withActivation;
bool transpose_to_reshape;
std::tie(netPrecision, targetDevice, configuration, inputShape, withActivation, transpose_to_reshape) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
size_t shape_size = inputShape.size();
@ -280,6 +293,12 @@ class RemovePermutationsWithPoolAndActTest : public testing::WithParamInterface<
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(reshape2) };
function = std::make_shared<ngraph::Function>(results, params, "RemovePermutationsWithPoolAndActTest");
if (transpose_to_reshape) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::TransposeToReshape>();
manager.run_passes(function);
}
}
};
@ -512,7 +531,8 @@ class RemovePermutationsWithEltwiseTest : public testing::WithParamInterface<rem
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(std::vector<bool>{false, true}), // with 1d output of convolution
::testing::ValuesIn(std::vector<bool>{false, true})), // transpose to reshape
RemovePermutationsNHWCToNCHWPassTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_PermutationPass, RemovePermutationsNHWCToNCHWPassNoReshapesTest,
@ -529,7 +549,8 @@ class RemovePermutationsWithEltwiseTest : public testing::WithParamInterface<rem
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(std::vector<bool>{false, true}), // with activation
::testing::ValuesIn(std::vector<bool>{false, true})), // transpose to reshape
RemovePermutationsWithPoolAndActTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_PermutationPass, RemovePermutationsWithTwoConvTest,
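Finally, a rough orientation sketch: the full SetUp bodies are elided in this diff, but from the test names and the PR description these RemovePermutations tests build the NHWC-emulation pattern that the new kaldi IRs carry — a 2D input reshaped to 4D, a Transpose into NCHW in front of the Convolution, a Transpose back, and a flatten to 2D — and check that the GNA plugin removes those layout permutations (optionally after TransposeToReshape has folded them into Reshapes). The code below is my reconstruction of that graph shape with made-up dimensions and filter values; it is not the code of these tests.

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>

std::shared_ptr<ngraph::Function> nhwcConvPatternSketch() {
    // 2D input, reinterpreted as NHWC {1, 1, 8, 64}.
    auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 512});
    auto to_nhwc = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 8, 64});
    auto reshape_in = std::make_shared<ngraph::opset1::Reshape>(input, to_nhwc, false);
    // NHWC -> NCHW in front of the convolution.
    auto perm_in = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 3, 1, 2});
    auto transpose_in = std::make_shared<ngraph::opset1::Transpose>(reshape_in, perm_in);
    // Convolution with a dummy 8x64x1x3 filter (C_out, C_in, kH, kW).
    auto filters = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{8, 64, 1, 3},
                                                    std::vector<float>(8 * 64 * 3, 0.1f));
    auto conv = std::make_shared<ngraph::opset1::Convolution>(transpose_in, filters,
        ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{1, 1});
    // NCHW -> NHWC behind the convolution, then flatten back to 2D.
    auto perm_out = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 2, 3, 1});
    auto transpose_out = std::make_shared<ngraph::opset1::Transpose>(conv, perm_out);
    auto to_2d = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 48});
    auto reshape_out = std::make_shared<ngraph::opset1::Reshape>(transpose_out, to_2d, false);
    return std::make_shared<ngraph::Function>(
        ngraph::ResultVector{std::make_shared<ngraph::opset1::Result>(reshape_out)},
        ngraph::ParameterVector{input});
}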