[GNA] Fix Transpose-Conv-Transpose pattern recognition for 2d convolution (#6715)
This commit is contained in:
parent
7d85d61083
commit
e7a00e9b31
@ -65,9 +65,11 @@ inline std::pair<InferenceEngine::CNNLayerPtr, InferenceEngine::CNNLayerPtr> Fin
|
||||
if (parent->outData.size() != 1 || InferenceEngine::getInputTo(parent->outData[0]).size() != 1) {
|
||||
return std::make_pair(nullptr, nullptr);
|
||||
}
|
||||
auto parent_dims = parent->outData[0]->getDims();
|
||||
// Check if the previous layer has all dimensions except one to be equal to 1
|
||||
if (std::count_if(std::begin(parent_dims), std::end(parent_dims), [](size_t dim) { return dim != 1; }) > 1) {
|
||||
// Check if reshape is expected for this pattern:
|
||||
// the previous layer has number of channels > 1 and one of height/width dimensions is also > 1
|
||||
if (GetDataDimSize(parent->outData[0], InferenceEngine::DataDimName::C) != 1 &&
|
||||
(GetDataDimSize(parent->outData[0], InferenceEngine::DataDimName::H) != 1 ||
|
||||
GetDataDimSize(parent->outData[0], InferenceEngine::DataDimName::W) != 1)) {
|
||||
return std::make_pair(nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
@ -36,6 +36,11 @@ typedef std::tuple<
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
std::vector<size_t> GetKernelShape(size_t height, size_t width, size_t kernel_size) {
|
||||
return (height == 1 ? std::vector<size_t>{1, kernel_size} :
|
||||
(width == 1 ? std::vector<size_t>{kernel_size, 1} : std::vector<size_t>{kernel_size, kernel_size}));
|
||||
}
|
||||
|
||||
class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<removePermutationsPassParams>,
|
||||
public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
@ -82,16 +87,15 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 3, 1, 2 }));
|
||||
|
||||
size_t num_out_channels = 12;
|
||||
size_t kernel_size = 8;
|
||||
std::vector<size_t> kernal_shape = (inputShape[1] == 1 ? std::vector<size_t>{1, kernel_size} : std::vector<size_t>{kernel_size, 1});
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernal_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
auto kernel_shape = GetKernelShape(inputShape[1], inputShape[2], 8);
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernel_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
ngraph::op::PadType::VALID, num_out_channels);
|
||||
|
||||
auto permute2 = std::make_shared<ngraph::opset1::Transpose>(conv1,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 2, 3, 1 }));
|
||||
|
||||
size_t out_width = (inputShape[2] - kernal_shape[1]) + 1;
|
||||
size_t out_height = (inputShape[1] - kernal_shape[0]) + 1;
|
||||
size_t out_width = (inputShape[2] - kernel_shape[1]) + 1;
|
||||
size_t out_height = (inputShape[1] - kernel_shape[0]) + 1;
|
||||
std::vector<size_t> outFormShapes = { 1, out_width * out_height * num_out_channels };
|
||||
auto pattern2 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 2 }, outFormShapes);
|
||||
auto reshape2 = std::make_shared<ngraph::opset1::Reshape>(permute2, pattern2, false);
|
||||
@ -132,9 +136,8 @@ protected:
|
||||
auto permute1 = std::make_shared<ngraph::opset1::Transpose>(params[0],
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 3, 1, 2 }));
|
||||
|
||||
size_t kernal_size = 8;
|
||||
std::vector<size_t> kernal_shape = (inputShape[1] == 1 ? std::vector<size_t>{1, kernal_size} : std::vector<size_t>{kernal_size, 1});
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernal_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 }, ngraph::op::PadType::VALID, 12);
|
||||
auto kernel_shape = GetKernelShape(inputShape[1], inputShape[2], 8);
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernel_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 }, ngraph::op::PadType::VALID, 12);
|
||||
|
||||
auto permute2 = std::make_shared<ngraph::opset1::Transpose>(conv1,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 2, 3, 1 }));
|
||||
@ -209,18 +212,18 @@ class RemovePermutationsWithPoolAndActTest : public testing::WithParamInterface<
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 3, 1, 2 }));
|
||||
|
||||
size_t num_out_channels = 12;
|
||||
size_t kernal_size = 8;
|
||||
auto kernal_shape = (inputShape[1] == 1 ? std::vector<size_t>{1, kernal_size} : std::vector<size_t>{kernal_size, 1});
|
||||
std::vector<float> filter_weights = CommonTestUtils::generate_float_numbers(num_out_channels * inputShape[3] * kernal_size,
|
||||
auto kernel_shape = GetKernelShape(inputShape[1], inputShape[2], 8);
|
||||
std::vector<float> filter_weights = CommonTestUtils::generate_float_numbers(num_out_channels * inputShape[3] *
|
||||
kernel_shape[0] * kernel_shape[1],
|
||||
-0.2f, 0.2f);
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernal_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernel_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
ngraph::op::PadType::VALID, num_out_channels, false, filter_weights);
|
||||
auto pool_kernal_shape = (inputShape[1] == 1 ? std::vector<size_t>{1, 2} : std::vector<size_t>{2, 1});
|
||||
auto pool = ngraph::builder::makePooling(conv1, pool_kernal_shape, {0, 0}, {0, 0}, pool_kernal_shape, ngraph::op::RoundingType::FLOOR,
|
||||
ngraph::op::PadType::VALID, false, ngraph::helpers::PoolingTypes::MAX);
|
||||
|
||||
size_t out_width = ((inputShape[2] - kernal_shape[1]) + 1) / pool_kernal_shape[1];
|
||||
size_t out_height = ((inputShape[1] - kernal_shape[0]) + 1) / pool_kernal_shape[0];
|
||||
size_t out_width = ((inputShape[2] - kernel_shape[1]) + 1) / pool_kernal_shape[1];
|
||||
size_t out_height = ((inputShape[1] - kernel_shape[0]) + 1) / pool_kernal_shape[0];
|
||||
|
||||
auto pool_output = pool;
|
||||
if (withActivation) {
|
||||
@ -299,21 +302,24 @@ class RemovePermutationsWithTwoConvTest : public testing::WithParamInterface<rem
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 3, 1, 2 }));
|
||||
|
||||
size_t num_out_channels = 12;
|
||||
size_t kernal_size = 8;
|
||||
std::vector<size_t> kernal_shape = (inputShape[1] == 1 ? std::vector<size_t>{1, kernal_size} : std::vector<size_t>{kernal_size, 1});
|
||||
std::vector<float> filter_weights_1 = CommonTestUtils::generate_float_numbers(num_out_channels * inputShape[3] * kernal_size,
|
||||
size_t kernel_size = 8;
|
||||
auto kernel_shape1 = GetKernelShape(inputShape[1], inputShape[2], kernel_size);
|
||||
std::vector<float> filter_weights_1 = CommonTestUtils::generate_float_numbers(num_out_channels * inputShape[3] *
|
||||
kernel_shape1[0] * kernel_shape1[1],
|
||||
0.0f, 0.5f);
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernal_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernel_shape1, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
ngraph::op::PadType::VALID, num_out_channels, false, filter_weights_1);
|
||||
size_t out_width = ((inputShape[2] - kernal_shape[1]) + 1);
|
||||
size_t out_height = ((inputShape[1] - kernal_shape[0]) + 1);
|
||||
size_t out_width = conv1->get_output_shape(0).at(3);
|
||||
size_t out_height = conv1->get_output_shape(0).at(2);
|
||||
|
||||
std::vector<float> filter_weights_2 = CommonTestUtils::generate_float_numbers(num_out_channels * num_out_channels * kernal_size,
|
||||
std::vector<size_t> kernel_shape2 = (out_height == 1 ? std::vector<size_t>{1, kernel_size} : std::vector<size_t>{kernel_size, 1});
|
||||
std::vector<float> filter_weights_2 = CommonTestUtils::generate_float_numbers(num_out_channels * num_out_channels *
|
||||
kernel_size,
|
||||
-0.2f, 0.2f);
|
||||
auto conv2 = ngraph::builder::makeConvolution(conv1, ngPrc, kernal_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
auto conv2 = ngraph::builder::makeConvolution(conv1, ngPrc, kernel_shape2, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
ngraph::op::PadType::VALID, num_out_channels, false, filter_weights_2);
|
||||
out_width = ((out_width - kernal_shape[1]) + 1);
|
||||
out_height = ((out_height - kernal_shape[0]) + 1);
|
||||
out_width = conv2->get_output_shape(0).at(3);
|
||||
out_height = conv2->get_output_shape(0).at(2);
|
||||
|
||||
auto permute2 = std::make_shared<ngraph::opset1::Transpose>(conv2,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 2, 3, 1 }));
|
||||
@ -391,11 +397,11 @@ class RemovePermutationsWithEltwiseTest : public testing::WithParamInterface<rem
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 3, 1, 2 }));
|
||||
|
||||
size_t num_out_channels = 12;
|
||||
size_t kernal_size = 8;
|
||||
std::vector<size_t> kernal_shape = (inputShape[1] == 1 ? std::vector<size_t>{1, kernal_size} : std::vector<size_t>{kernal_size, 1});
|
||||
std::vector<float> filter_weights_1 = CommonTestUtils::generate_float_numbers(num_out_channels * in_channels * kernal_size,
|
||||
auto kernel_shape = GetKernelShape(inputShape[1], inputShape[2], 8);
|
||||
std::vector<float> filter_weights_1 = CommonTestUtils::generate_float_numbers(num_out_channels * in_channels *
|
||||
kernel_shape[0] * kernel_shape[1],
|
||||
-0.2f, 0.2f);
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernal_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, kernel_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
ngraph::op::PadType::VALID, num_out_channels, false, filter_weights_1);
|
||||
|
||||
auto pattern2 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 4 }, inputShape);
|
||||
@ -403,9 +409,10 @@ class RemovePermutationsWithEltwiseTest : public testing::WithParamInterface<rem
|
||||
auto permute2 = std::make_shared<ngraph::opset1::Transpose>(reshape2,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 3, 1, 2 }));
|
||||
|
||||
std::vector<float> filter_weights_2 = CommonTestUtils::generate_float_numbers(num_out_channels * in_channels * kernal_size,
|
||||
std::vector<float> filter_weights_2 = CommonTestUtils::generate_float_numbers(num_out_channels * in_channels *
|
||||
kernel_shape[0] * kernel_shape[1],
|
||||
-0.2f, 0.2f);
|
||||
auto conv2 = ngraph::builder::makeConvolution(permute2, ngPrc, kernal_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
auto conv2 = ngraph::builder::makeConvolution(permute2, ngPrc, kernel_shape, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 },
|
||||
ngraph::op::PadType::VALID, num_out_channels, false, filter_weights_2);
|
||||
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(conv1, conv2);
|
||||
@ -413,8 +420,8 @@ class RemovePermutationsWithEltwiseTest : public testing::WithParamInterface<rem
|
||||
auto permute3 = std::make_shared<ngraph::opset1::Transpose>(add,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 2, 3, 1 }));
|
||||
|
||||
size_t out_width = ((in_width - kernal_shape[1]) + 1);
|
||||
size_t out_height = ((in_height - kernal_shape[0]) + 1);
|
||||
size_t out_width = ((in_width - kernel_shape[1]) + 1);
|
||||
size_t out_height = ((in_height - kernel_shape[0]) + 1);
|
||||
std::vector<size_t> outFormShapes = { 1, out_width * out_height * num_out_channels };
|
||||
auto pattern3 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 2 }, outFormShapes);
|
||||
auto reshape3 = std::make_shared<ngraph::opset1::Reshape>(permute3, pattern3, false);
|
||||
@ -468,7 +475,8 @@ class RemovePermutationsWithEltwiseTest : public testing::WithParamInterface<rem
|
||||
{1, 168, 1, 8},
|
||||
{1, 32, 1, 1},
|
||||
{1, 32, 1, 2},
|
||||
{1, 32, 1, 8}
|
||||
{1, 32, 1, 8},
|
||||
{1, 16, 8, 1}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_PermutationPass, RemovePermutationsNHWCToNCHWPassTest,
|
||||
|
Loading…
Reference in New Issue
Block a user