[LPT] Fixed an incorrect condition & added test to MoveFakeQuantize transformation (#10009)

* fixed an incorrect condition & added test

* fixed an incorrect condition & added test
This commit is contained in:
Nikita Demashov
2022-02-07 12:32:49 +03:00
committed by GitHub
parent b365e67561
commit 74fa60cf86
2 changed files with 65 additions and 1 deletions

View File

@@ -86,7 +86,7 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
const auto concat_axis = concat_node->get_concatenation_axis();
for (size_t i = 0; i < 4; i++) {
curr_constants[i] = as_type_ptr<opset1::Constant>(fq->get_input_node_shared_ptr(i + 1));
if (!multi_chanels && curr_constants[i]->get_shape().size() > (concat_axis + 1ul) && curr_constants[i]->get_shape()[concat_axis] != 1) {
if (!multi_chanels && curr_constants[i]->get_shape().size() > concat_axis && curr_constants[i]->get_shape()[concat_axis] != 1) {
multi_chanels = true;
}
}

View File

@@ -585,4 +585,68 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn({ false })),
MoveFakeQuantizeTransformation::getTestCaseName);
} // namespace testValues2
namespace testValues3 {
const std::vector<ngraph::element::Type> precisions = {
ngraph::element::f32,
ngraph::element::f16
};
const std::vector<std::vector<ngraph::PartialShape>> shapes = {
{{ 1, 1}, { 1, 2}},
{{ 4, 1}, { 4, 2}}
};
const std::vector<MoveFakeQuantizeTransformationTestValues> testValues = {
// 2D shape
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
2,
{},
{},
{},
"",
{
256ul,
{{1, 3}, {1, 3}, {}, {}},
{-31.7f, -35.7f, -49.1f},
{277.8f, 267.f, 254.9f},
{-2.6f}, {2.6f},
},
{},
{}
},
{
2,
{
{256ul,
{{1, 1}, {1, 1}, {}, {}},
{-31.7f}, {277.8f}, {-2.6}, {2.6f}},
{256ul,
{{1, 2}, {1, 2}, {}, {}},
{-35.7f, -49.1f},
{267.f, 254.9f},
{-2.6f}, {2.6f}}
},
{},
{},
"",
{},
{},
{},
}
},
};
INSTANTIATE_TEST_SUITE_P(
smoke_LPT,
MoveFakeQuantizeTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(shapes),
::testing::ValuesIn(testValues),
::testing::ValuesIn({ false })),
MoveFakeQuantizeTransformation::getTestCaseName);
} // namespace testValues3
} // namespace