[LPT] AddTransformation fix (#17076)

* [LPT] AddTransformation: constants on 0's input support

* AddTransformation: new test instances

* codestyle
This commit is contained in:
Vladislav Golubev
2023-04-24 13:15:01 +02:00
committed by GitHub
parent a3f14366d9
commit f410658d32
4 changed files with 80 additions and 37 deletions

View File

@@ -60,9 +60,9 @@ bool EltwiseBaseTransformation::canBeTransformed(const TransformationContext& co
}
// Returns true when `node` is a layer type that an Add can be fused into as a
// bias: Convolution, GroupConvolution or MatMul. Null-safe: a nullptr input
// (e.g. a dangling parent lookup) simply yields false.
// NOTE(review): the original span was a diff hunk with +/- markers stripped, so
// the pre-fix body (no nullptr check) appeared as an unreachable duplicate
// return above the post-fix one; only the null-safe version is kept.
static bool isTargetType(const std::shared_ptr<Node> node) {
    return node != nullptr && (ov::is_type<opset1::Convolution>(node) ||
                               ov::is_type<opset1::GroupConvolution>(node) ||
                               ov::is_type<opset1::MatMul>(node));
}
// NOTE(review): this span is a rendered diff hunk with +/- markers stripped.
// The pre-fix body (input-0-only check) and the post-fix body (loop over all
// inputs) appear merged below; only the loop version is the intended final
// code. Context between the hunk header and the visible lines is elided, so
// the code is left byte-identical here — do not treat it as compilable as-is.
static std::shared_ptr<Node> getDataParent(const std::shared_ptr<Node> branchData) {
@@ -72,30 +72,38 @@ static std::shared_ptr<Node> getDataParent(const std::shared_ptr<Node> branchDat
}
// `parent` is presumably established in the elided lines above — TODO confirm.
if (ov::marked_as_bias(parent)) {
// --- pre-fix body (removed by the commit): only input 0 was inspected ---
const auto bias_parent = parent->get_input_node_shared_ptr(0);
// the target node is directly before the bias
if (isTargetType(bias_parent)) {
return bias_parent;
}
// some DQ (dequantization) operations are placed between the target node and the bias
const auto dq = NetworkHelper::getDequantization(parent->get_input_node_shared_ptr(0));
const auto data_node = dq.data.get_node_shared_ptr();
if (isTargetType(data_node)) {
return data_node;
// --- post-fix body (added by the commit): both inputs are inspected ---
// we need to check both inputs in order to handle the case with constant on 0's input
for (size_t i = 0; i < parent->get_input_size(); ++i) {
const auto bias_parent = parent->get_input_node_shared_ptr(i);
// the target node is directly before the bias
if (isTargetType(bias_parent)) {
return bias_parent;
}
// some DQ (dequantization) operations are placed between the target node and the bias
const auto dq = NetworkHelper::getDequantization(bias_parent);
const auto data_node = dq.data.get_node_shared_ptr();
if (isTargetType(data_node)) {
return data_node;
}
}
}
// no target node found on this branch: fall back to the immediate parent
return parent;
}
// Walks the chain from `branchData` up to (and including) `branchDataParent`
// and reports whether any node on the way has multiple outputs or feeds more
// than one consumer — in which case the branch cannot be folded safely.
// NOTE(review): the original span was a diff hunk with +/- markers stripped;
// the pre-fix lines (inline consumer check, unconditional input-0 walk) were
// merged with the post-fix ones, producing duplicate conditions and unbalanced
// braces. Only the post-fix version is kept.
static bool isBranchHaveMultipleConsumers(const std::shared_ptr<Node> branchData, const std::shared_ptr<Node> branchDataParent) {
    // True when the node has several outputs, or its single output has several target inputs.
    auto several_consumers = [](const std::shared_ptr<ov::Node>& node) {
        return node->get_output_size() != 1 || node->get_output_target_inputs(0).size() != 1;
    };
    auto parent = branchData;
    while (parent != branchDataParent) {
        if (several_consumers(parent)) {
            return true;
        }
        // Continue the walk through the non-constant input: this handles the
        // case where a Constant sits on input 0 and the data path is on input 1.
        const auto new_parent = parent->get_input_node_shared_ptr(0);
        parent = !ov::is_type<opset1::Constant>(new_parent) ? new_parent : parent->get_input_node_shared_ptr(1);
    }
    // `branchDataParent` itself must also be single-consumer.
    return several_consumers(parent);
}
// return branch index with FP32 precision after eltwise transformation

View File

@@ -69,6 +69,7 @@ public:
Actual actual;
Expected expected;
std::string additionalLayer;
std::string postops_configuration;
};
typedef std::tuple<ngraph::element::Type,
@@ -94,7 +95,8 @@ public:
testValues.actual.dequantization2,
testValues.constInput,
testValues.actual.constValues,
testValues.additionalLayer);
testValues.additionalLayer,
testValues.postops_configuration);
SimpleLowPrecisionTransformer transform;
transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(testValues.params);
@@ -120,7 +122,8 @@ public:
testValues.constInput == -1 ? -1 : 1,
testValues.expected.constValues,
testValues.additionalLayer,
testValues.expected.operationType);
testValues.expected.operationType,
testValues.postops_configuration);
}
static std::string getTestCaseName(testing::TestParamInfo<AddTransformationParams> obj) {
@@ -133,7 +136,8 @@ public:
<< "_" << testValues.actual.precision1 << "_" << testValues.actual.dequantization1 << "_"
<< testValues.actual.precision2 << "_" << testValues.actual.dequantization2 << "_"
<< testValues.constInput << "_" << testValues.actual.constValues << "_" << testValues.additionalLayer
<< "_" << (testValues.params.updatePrecisions ? "true" : "false");
<< "_" << testValues.postops_configuration << "_"
<< (testValues.params.updatePrecisions ? "true" : "false");
return result.str();
}
};
@@ -439,6 +443,23 @@ const std::vector<AddTransformationTestValues> testValuesWithoutConstantBranches
{{}, {}, {10.f}},
{}},
"convolution"},
// convolution before FQ
{false,
-1,
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8,
{{ngraph::element::f32}, {7.f}, {10.f}},
ngraph::element::u8,
{{ngraph::element::f32}, {3.f}, {5.f}},
{}},
{ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {17.f}, {0.5f}},
{{}, {}, {10.f}},
{}},
"convolution",
"bias_on_zero_input"},
// convolution with multiple consumers before FQ ( FP32 on other branch due to possible quantize fusing )
{false,
-1,