[LPT] NNCF GroupConvolution 5D on weights support (#16336)

* [LPT] NNCF GroupConvolution 5D on weights support

* PullReshapeThroughDequantization rollback
This commit is contained in:
Edward Shogulin 2023-03-23 13:24:10 +00:00 committed by GitHub
parent 8a246a8bf2
commit fb24e91416
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1171 additions and 595 deletions

View File

@ -237,8 +237,15 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
Shape newScaleShape = newScalePShape.to_shape();
if (!newScaleShape.empty()) {
// that's all we need: [C, 1, 1, 1] => [C, 1, 1]
newScaleShape.pop_back();
const auto input_shape = convolution->get_input_partial_shape(0);
const auto diff = newScaleShape.size() - input_shape.size();
OPENVINO_ASSERT(
newScaleShape.empty() || ((0 <= diff) && (diff <= 2ull)),
"unexpected shape size on weights");
for (size_t i = 0; i <= diff; ++i) {
newScaleShape.pop_back();
}
}
if (reshapeFromWeights != nullptr) {
@ -282,7 +289,12 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
const size_t weightsRankValue = weightsPShape.rank().get_length();
Shape zeroPointShape(weightsRankValue, 1ul);
// output channel or group
zeroPointShape[0] = static_cast<size_t>(weightsPShape[0].get_length());
if ((reshapeFromWeights == nullptr) && (weightsRankValue == 5ull)) {
// output channel
zeroPointShape[1] = static_cast<size_t>(weightsPShape[1].get_length());
}
auto zeroPointConstant = fold<opset1::Broadcast>(
subtractFromWeights->input_value(1),

View File

@ -230,16 +230,16 @@ bool WeightableLayerTransformation::isQuantizedStatic(const std::shared_ptr<cons
FakeQuantizeDequantization dequantizationOnWeights;
if (reshapeIsRequired) {
const auto reshape = layer->get_input_node_shared_ptr(1);
if (!ov::is_type<opset1::Reshape>(reshape)) {
return false;
}
std::shared_ptr<Node> parent = ov::is_type<opset1::Reshape>(reshape) ?
reshape->get_input_node_shared_ptr(0) :
reshape;
if (ov::is_type<opset1::FakeQuantize>(reshape->get_input_node_shared_ptr(0))) {
const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(reshape->get_input_node_shared_ptr(0));
const auto fq = ov::as_type_ptr<opset1::FakeQuantize>(parent);
if (fq != nullptr) {
return NetworkHelper::isQuantizeSupported(fq);
}
dequantizationOnWeights = NetworkHelper::getDequantization(reshape, defaultPrecisions, 0);
dequantizationOnWeights = NetworkHelper::getDequantization(parent, defaultPrecisions, 0, true);
} else if (ov::is_type<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1))) {
const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1));
return NetworkHelper::isQuantizeSupported(fq);

View File

@ -133,11 +133,17 @@ TEST_P(PullReshapeThroughDequantizationTransformation, CompareFunctions) {
ASSERT_TRUE(res.first) << res.second;
}
const std::vector<ngraph::Shape> inputShapes = {ngraph::Shape({1, 960, 7, 7}), ngraph::Shape({4, 960, 7, 7})};
// clang-format off
const std::vector<ngraph::Shape> inputShapes = {
ngraph::Shape({1, 960, 7, 7}),
ngraph::Shape({4, 960, 7, 7})
};
const std::vector<std::pair<ngraph::Shape, ngraph::Shape>> dequantizationOnWeightElementwiseConstantShapes = {
{ngraph::Shape({1, 960}), ngraph::Shape({960, 1, 1, 1})},
{ngraph::Shape({9, 960}), ngraph::Shape({960, 1, 3, 3})}};
{ngraph::Shape({9, 960}), ngraph::Shape({960, 1, 3, 3})}
};
const std::vector<ngraph::Shape> multiplyShapes = {ngraph::Shape({1, 1, 960, 1})};
@ -193,37 +199,51 @@ const std::vector<PullReshapeThroughDequantizationTestValues> testValues = {
// \ /
// Multiply
//
{LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
// ActualValues
{ngraph::element::u8,
{{ngraph::element::f32, false},
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}},
{std::vector<float>{2.f}, ngraph::element::i8, {9, 960}},
{{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false},
{{0.03f}, element::f32, {/* from parameter */}, false}},
{{3, 3, 960, 1}},
{{2}, element::f32, {/* from parameter: multiplyShapes */}, false},
{{2, 3, 0, 1}},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}},
// ExpectedValues
{ngraph::element::u8,
{{ngraph::element::f32, false},
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}},
{std::vector<float>{2.f}, ngraph::element::i8, {960, 1, 3, 3}},
{{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false},
{{0.06f}, element::f32, {/* from parameter */}, false}},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}}},
{
LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
// ActualValues
{
ngraph::element::u8,
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {}, false, 1ul, element::u8, true },
{ {0.02f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8, {9, 960}},
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {/* from parameter */}, false },
{ {0.03f}, element::f32, {/* from parameter */}, false }
},
{ {3, 3, 960, 1} },
{ {2}, element::f32, {/* from parameter: multiplyShapes */}, false },
{ {2, 3, 0, 1} },
{ {960, 1, 1, 3, 3} },
ngraph::element::f32,
{}
},
// ExpectedValues
{
ngraph::element::u8,
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {}, false, 1ul, element::u8, true },
{ {0.02f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8, {960, 1, 3, 3}},
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {/* from parameter */}, false },
{ {0.06f}, element::f32, {/* from parameter */}, false }
},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}
}
},
// Subtract with Convert + Constant
// Actual:
@ -276,37 +296,54 @@ const std::vector<PullReshapeThroughDequantizationTestValues> testValues = {
// \ /
// Multiply
//
{LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
// ActualValues
{ngraph::element::u8,
{{ngraph::element::f32, false},
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}},
{std::vector<float>{2.f}, ngraph::element::i8, {9, 960}},
{{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false, 1ul, element::i8, true},
{{0.03f}, element::f32, {/* from parameter */}, false}},
{{3, 3, 960, 1}},
{{2}, element::f32, {/* from parameter: multiplyShapes */}, false},
{{2, 3, 0, 1}},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}},
// ExpectedValues
{ngraph::element::u8,
{{ngraph::element::f32, false},
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}},
{std::vector<float>{2.f}, ngraph::element::i8, {960, 1, 3, 3}},
{{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false, 1ul, element::i8, true},
{{0.06f}, element::f32, {/* from parameter */}, false}},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}}}};
{
LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
// ActualValues
{
ngraph::element::u8,
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {}, false, 1ul, element::u8, true },
{ {0.02f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8, {9, 960}},
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {/* from parameter */}, false, 1ul, element::i8, true },
{ {0.03f}, element::f32, {/* from parameter */}, false }
},
{ {3, 3, 960, 1} },
{ {2}, element::f32, {/* from parameter: multiplyShapes */}, false },
{ {2, 3, 0, 1} },
{ {960, 1, 1, 3, 3} },
ngraph::element::f32,
{}
},
// ExpectedValues
{
ngraph::element::u8,
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {}, false, 1ul, element::u8, true },
{ {0.02f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8, {960, 1, 3, 3}},
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {/* from parameter */}, false, 1ul, element::i8, true },
{ {0.06f}, element::f32, {/* from parameter */}, false }
},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}
}
}
};
// clang-format on
INSTANTIATE_TEST_SUITE_P(smoke_LPT,
PullReshapeThroughDequantizationTransformation,

View File

@ -126,7 +126,12 @@ TEST_P(PullTransposeThroughDequantizationTransformation, CompareFunctions) {
ASSERT_TRUE(res.first) << res.second;
}
const std::vector<ngraph::Shape> inputShapes = {ngraph::Shape({1, 960, 7, 7}), ngraph::Shape({4, 960, 7, 7})};
// clang-format off
const std::vector<ngraph::Shape> inputShapes = {
ngraph::Shape({1, 960, 7, 7}),
ngraph::Shape({4, 960, 7, 7})
};
const std::vector<std::pair<ngraph::Shape, ngraph::Shape>> dequantizationOnWeightElementwiseConstantShapes = {
{ngraph::Shape({}), ngraph::Shape({1, 1, 1, 1})},
@ -178,37 +183,54 @@ const std::vector<PullTransposeThroughDequantizationTestValues> testValues = {
// \ /
// Multiply
//
{LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
// ActualValues
{ngraph::element::u8,
{{ngraph::element::f32, false},
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}},
{std::vector<float>{2.f}, ngraph::element::i8, {3, 3, 960, 1}},
{{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false},
{{0.03f}, element::f32, {/* from parameter */}, false}},
{}, // reshape1
{}, // multiply
{{2, 3, 0, 1}},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}},
// ExpectedValues
{ngraph::element::u8,
{{ngraph::element::f32, false},
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}},
{std::vector<float>{2.f}, ngraph::element::i8, {960, 1, 3, 3}},
{{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false},
{{0.03f}, element::f32, {/* from parameter */}, false}},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}}}};
{
LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
// ActualValues
{
ngraph::element::u8,
{
{ngraph::element::f32, false},
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}
},
{std::vector<float>{2.f}, ngraph::element::i8, {3, 3, 960, 1}},
{
{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false},
{{0.03f}, element::f32, {/* from parameter */}, false}
},
{}, // reshape1
{}, // multiply
{{2, 3, 0, 1}},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}
},
// ExpectedValues
{
ngraph::element::u8,
{
{ngraph::element::f32, false},
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}
},
{std::vector<float>{2.f}, ngraph::element::i8, {960, 1, 3, 3}},
{
{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false},
{{0.03f}, element::f32, {/* from parameter */}, false}
},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}
}
}
};
// clang-format on
INSTANTIATE_TEST_SUITE_P(smoke_LPT,
PullTransposeThroughDequantizationTransformation,

View File

@ -11,6 +11,8 @@
using namespace LayerTestsDefinitions;
namespace {
// clang-format off
const std::vector<ngraph::element::Type> netPrecisions = {
ngraph::element::f32,
// ngraph::element::f16
@ -370,6 +372,66 @@ const std::vector<LayerTestsDefinitions::GroupConvolutionQDqTransformationParam>
true,
},
// Actual:
//
// FQ
// |FP32
// |
// Convert Convert Constant Constant
// |U8 |U8 |I8 |I8
// | | | |
// Convert Convert Convert Convert
// \FP32 /FP32 \FP32 /FP32
// \ / \ /
// Subtract Constant Subtract Constant
// \FP32 /FP32 \FP32 /FP32
// \ / \ /
// Multiply Multiply
// \FP32 /FP32
// \ /
// \ /
// \ /
// GroupConvolution Constant
// \FP32 /FP32
// \ /
// Multiply
//
// Transformed:
//
// FQ Constant Constant
// \U8 /U8 / I8
// \ / /
// Subtract Subtract
// \FP32 /FP32
// \ /
// \ /
// \ /
// GroupConvolution Constant
// \FP32 /FP32
// \ /
// Multiply
{
{ 256ul, {{ 1, 1, 1, 1 }}, { -12.8f }, { 12.7f }, { 0.f }, { 255.f }, ngraph::element::f32 },
{ ngraph::element::u8, false },
{
{ ngraph::element::f32, false },
{ {128.f}, ngraph::element::f32, {}, false, 1ul, ngraph::element::u8, true },
{ {0.1f}, ngraph::element::f32, {}, false }
},
{ std::vector<float>(4, 15.f), ngraph::element::i8, {2, 1, 2, 1, 1} },
{},
{},
{
{ ngraph::element::f32, false },
{ {126.f, 127.f}, ngraph::element::f32, {2, 1, 1, 1, 1}, false, 1ul, ngraph::element::i8, true },
{ {0.1f, 0.2f}, ngraph::element::f32, {2, 1, 1, 1, 1}, false }
},
{},
"output_original",
"FP32",
true,
},
// Actual:
//
// FQ
@ -427,6 +489,63 @@ const std::vector<LayerTestsDefinitions::GroupConvolutionQDqTransformationParam>
false,
},
// Actual:
//
// FQ
// |FP32
// |
// Convert Convert
// |U8 |U8
// | |
// Convert Convert Constant
// \FP32 /FP32 \U8
// \ / \
// Subtract Constant Convert Constant
// \FP32 /FP32 \FP32 /FP32
// \ / \ /
// Multiply Multiply
// \FP32 /FP32
// \ /
// \ /
// \ /
// GroupConvolution
//
// Transformed:
//
// FQ Constant
// \U8 /U8
// \ /
// Subtract
// \FP32
// \ Constant
// \ /I8
// \ /
// GroupConvolution Constant
// \FP32 /FP32
// \ /
// Multiply
{
{ 256ul, {{ 1, 1, 1, 1 }}, { -12.8f }, { 12.7f }, { 0.f }, { 255.f }, ngraph::element::f32 },
{ ngraph::element::u8, false },
{
{ ngraph::element::f32, false },
{ {128.f}, ngraph::element::f32, {}, false, 1ul, ngraph::element::u8, true },
{ {0.1f}, ngraph::element::f32, {}, false }
},
{ std::vector<float>(4, 15.f), ngraph::element::i8, {2, 1, 2, 1, 1} },
{},
{},
{
{ ngraph::element::f32, false },
{},
{ {0.1f, 0.2f}, ngraph::element::f32, {2, 1, 1, 1, 1}, false }
},
{},
"output_original",
"U8",
false,
},
// Actual:
//
// FQ
@ -500,4 +619,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, GroupConvolutionQDqTransformation,
::testing::ValuesIn(trasformationParamValues),
::testing::ValuesIn(params)),
GroupConvolutionQDqTransformation::getTestCaseName);
// clang-format on
} // namespace

View File

@ -11,6 +11,8 @@
using namespace LayerTestsDefinitions;
namespace {
// clang-format off
const std::vector<ngraph::element::Type> netPrecisions = {
ngraph::element::f32,
// ngraph::element::f16
@ -370,6 +372,66 @@ const std::vector<LayerTestsDefinitions::GroupConvolutionQDqTransformationParam>
true,
},
// Actual:
//
// FQ
// |FP32
// |
// Convert Convert Constant Constant
// |U8 |U8 |I8 |I8
// | | | |
// Convert Convert Convert Convert
// \FP32 /FP32 \FP32 /FP32
// \ / \ /
// Subtract Constant Subtract Constant
// \FP32 /FP32 \FP32 /FP32
// \ / \ /
// Multiply Multiply
// \FP32 /FP32
// \ /
// \ /
// \ /
// GroupConvolution Constant
// \FP32 /FP32
// \ /
// Multiply
//
// Transformed:
//
// FQ Constant Constant
// \U8 /U8 / I8
// \ / /
// Subtract Subtract
// \FP32 /FP32
// \ /
// \ /
// \ /
// GroupConvolution Constant
// \FP32 /FP32
// \ /
// Multiply
{
{ 256ul, {{ 1, 1, 1, 1 }}, { -12.8f }, { 12.7f }, { 0.f }, { 255.f }, ngraph::element::f32 },
{ ngraph::element::u8, false },
{
{ ngraph::element::f32, false },
{ {128.f}, ngraph::element::f32, {}, false, 1ul, ngraph::element::u8, true },
{ {0.1f}, ngraph::element::f32, {}, false }
},
{ std::vector<float>(4, 15.f), ngraph::element::i8, {2, 1, 2, 1, 1} },
{},
{},
{
{ ngraph::element::f32, false },
{ {126.f, 127.f}, ngraph::element::f32, {2, 1, 1, 1, 1}, false, 1ul, ngraph::element::i8, true },
{ {0.1f, 0.2f}, ngraph::element::f32, {2, 1, 1, 1, 1}, false }
},
{},
"output_original",
"FP32",
true,
},
// Actual:
//
// FQ
@ -427,6 +489,63 @@ const std::vector<LayerTestsDefinitions::GroupConvolutionQDqTransformationParam>
false,
},
// Actual:
//
// FQ
// |FP32
// |
// Convert Convert
// |U8 |U8
// | |
// Convert Convert Constant
// \FP32 /FP32 \U8
// \ / \
// Subtract Constant Convert Constant
// \FP32 /FP32 \FP32 /FP32
// \ / \ /
// Multiply Multiply
// \FP32 /FP32
// \ /
// \ /
// \ /
// GroupConvolution
//
// Transformed:
//
// FQ Constant
// \U8 /U8
// \ /
// Subtract
// \FP32
// \ Constant
// \ /I8
// \ /
// GroupConvolution Constant
// \FP32 /FP32
// \ /
// Multiply
{
{ 256ul, {{ 1, 1, 1, 1 }}, { -12.8f }, { 12.7f }, { 0.f }, { 255.f }, ngraph::element::f32 },
{ ngraph::element::u8, false },
{
{ ngraph::element::f32, false },
{ {128.f}, ngraph::element::f32, {}, false, 1ul, ngraph::element::u8, true },
{ {0.1f}, ngraph::element::f32, {}, false }
},
{ std::vector<float>(4, 15.f), ngraph::element::i8, {2, 1, 2, 1, 1} },
{},
{},
{
{ ngraph::element::f32, false },
{},
{ {0.1f, 0.2f}, ngraph::element::f32, {2, 1, 1, 1, 1}, false }
},
{},
"output_original",
"U8",
false,
},
// Actual:
//
// FQ
@ -500,4 +619,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, GroupConvolutionQDqTransformation,
::testing::ValuesIn(trasformationParamValues),
::testing::ValuesIn(params)),
GroupConvolutionQDqTransformation::getTestCaseName);
// clang-format on
} // namespace

View File

@ -49,7 +49,8 @@ public:
const ngraph::builder::subgraph::DequantizationOperations& dequantizationOnWeights,
const ngraph::element::Type precisionAfterOperation,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter,
const ngraph::element::Type precisionAfterDequantization);
const ngraph::element::Type precisionAfterDequantization,
const bool addReshape);
};
} // namespace subgraph

View File

@ -31,7 +31,8 @@ std::shared_ptr<Node> createWeightsOriginal(
const size_t kernelSize,
const std::vector<float>& weightsValues,
const FakeQuantizeOnWeights& fakeQuantizeOnWeights,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationOnWeights) {
const ngraph::builder::subgraph::DequantizationOperations& dequantizationOnWeights,
const bool addReshape = true) {
std::shared_ptr<Node> weights;
if (fakeQuantizeOnWeights.empty() && dequantizationOnWeights.empty()) {
weights = ngraph::opset1::Constant::create(
@ -46,9 +47,13 @@ std::shared_ptr<Node> createWeightsOriginal(
const size_t inputChannelsPerGroup = inputChannelsCount / groupCount;
weights = ngraph::opset1::Constant::create(
precision,
rankLength == 3 ?
ngraph::Shape{ outputChannelsCount, inputChannelsPerGroup, kernelSize } :
ngraph::Shape{ outputChannelsCount, inputChannelsPerGroup, kernelSize, kernelSize },
addReshape ?
(rankLength == 3 ?
ngraph::Shape{ outputChannelsCount, inputChannelsPerGroup, kernelSize } :
ngraph::Shape{ outputChannelsCount, inputChannelsPerGroup, kernelSize, kernelSize }) :
(rankLength == 3 ?
ngraph::Shape{ groupCount, outputChannelsCount / groupCount, inputChannelsPerGroup, kernelSize } :
ngraph::Shape{ groupCount, outputChannelsCount / groupCount, inputChannelsPerGroup, kernelSize, kernelSize }),
weightsValues.size() == 1ul ?
std::vector<float>(
rankLength == 3 ?
@ -75,24 +80,26 @@ std::shared_ptr<Node> createWeightsOriginal(
weights = ngraph::builder::subgraph::makeDequantization(weights, dequantizationOnWeights);
}
weights = std::make_shared<ngraph::opset1::Reshape>(
weights,
ngraph::opset1::Constant::create(
element::i64,
Shape{ static_cast<size_t>(rankLength) + 1ul },
rankLength == 3 ?
std::vector<int64_t> {
calculatedDimention == 0 ? -1 : static_cast<int64_t>(groupCount),
calculatedDimention == 1 ? -1 : static_cast<int64_t>(outputChannelsCount / groupCount),
static_cast<int64_t>(inputChannelsPerGroup),
static_cast<int64_t>(kernelSize) } :
std::vector<int64_t> {
calculatedDimention == 0 ? -1 : static_cast<int64_t>(groupCount),
calculatedDimention == 1 ? -1 : static_cast<int64_t>(outputChannelsCount / groupCount),
static_cast<int64_t>(inputChannelsPerGroup),
static_cast<int64_t>(kernelSize),
static_cast<int64_t>(kernelSize) }),
true);
if (addReshape) {
weights = std::make_shared<ngraph::opset1::Reshape>(
weights,
ngraph::opset1::Constant::create(
element::i64,
Shape{ static_cast<size_t>(rankLength) + 1ul },
rankLength == 3 ?
std::vector<int64_t> {
calculatedDimention == 0 ? -1 : static_cast<int64_t>(groupCount),
calculatedDimention == 1 ? -1 : static_cast<int64_t>(outputChannelsCount / groupCount),
static_cast<int64_t>(inputChannelsPerGroup),
static_cast<int64_t>(kernelSize) } :
std::vector<int64_t> {
calculatedDimention == 0 ? -1 : static_cast<int64_t>(groupCount),
calculatedDimention == 1 ? -1 : static_cast<int64_t>(outputChannelsCount / groupCount),
static_cast<int64_t>(inputChannelsPerGroup),
static_cast<int64_t>(kernelSize),
static_cast<int64_t>(kernelSize) }),
true);
}
}
return weights;
@ -253,7 +260,8 @@ std::shared_ptr<ngraph::Function> GroupConvolutionFunction::get(
const ngraph::builder::subgraph::DequantizationOperations& dequantizationOnWeights,
const ngraph::element::Type precisionAfterOperation,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter,
const ngraph::element::Type precisionAfterDequantization) {
const ngraph::element::Type precisionAfterDequantization,
const bool addReshape) {
const auto rankLength = inputShape.rank().is_dynamic() ? 4 : inputShape.rank().get_length();
OPENVINO_ASSERT(rankLength == 3 || rankLength == 4, "not supported input shape rank: ", rankLength);
@ -269,9 +277,6 @@ std::shared_ptr<ngraph::Function> GroupConvolutionFunction::get(
const size_t outputChannelsInGroup = outputChannelsCount / groupCount;
const size_t weightsSize = weightsConst->cast_vector<float>().size();
if ((weightsSize != 1ul) && (weightsSize != (inputChannelsCount * outputChannelsCount))) {
throw std::runtime_error("unexpected actual weights values size");
}
std::shared_ptr<ngraph::Node> weights;
if (fakeQuantizeOnWeights.empty() && dequantizationOnWeights.empty()) {
@ -293,7 +298,8 @@ std::shared_ptr<ngraph::Function> GroupConvolutionFunction::get(
kernelSize,
weightsConst->cast_vector<float>(),
fakeQuantizeOnWeights,
dequantizationOnWeights);
dequantizationOnWeights,
addReshape);
}
auto convolutionOriginal = ngraph::opset1::GroupConvolution(