Don't use could_propagate in MarkDequantizationSubgraph (#15325)

Mateusz Tabaka 2023-02-13 11:28:38 +01:00 committed by GitHub
parent b9a1b45a82
commit 9a540e61dc
2 changed files with 144 additions and 5 deletions


@@ -346,3 +346,126 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}
TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstantWeights) {
// Input graph:
//
//   Parameter
//     |F32
//     |
//   FakeQuantize                     Constant
//     |F32                             |I8
//     |                                |
//   Convert       Constant           Clamp          Constant
//     |U8           |U8                |I8            |I8
//     |             |                  |              |
//   Convert     Convert(DCF)       Convert(DCF)   Convert(DCF)
//     \FP32       /FP32                |FP32        /F32
//      \         /                     |           /
//       Subtract      Constant       Subtract   Constant
//         \FP32        /FP32           |FP32     /FP32
//          \           /               |        /
//           Multiply                  Multiply
//              \FP32                  /FP32
//               \                    /
//                    Convolution
// After MarkDequantizationSubgraph, all Subtract and Multiply nodes in the graph above
// are marked with the 'DequantizationNode' attribute.
// All 'Convert(DCF)' nodes in the graph above are marked with the 'DisableConstantFolding' attribute.
{
auto parameter = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 16, 14, 14});
std::shared_ptr<Node> activations =
std::make_shared<opset10::FakeQuantize>(parameter,
opset10::Constant::create(element::f32, Shape{}, {0}),
opset10::Constant::create(element::f32, Shape{}, {20}),
opset10::Constant::create(element::f32, Shape{}, {0}),
opset10::Constant::create(element::f32, Shape{}, {254}),
255);
{
auto first_convert = std::make_shared<opset10::Convert>(activations, element::u8);
auto second_convert = std::make_shared<opset10::Convert>(first_convert, element::f32);
auto zero_point = opset10::Constant::create(element::u8, Shape{}, {127});
auto convert_on_zero_point = std::make_shared<opset10::Convert>(zero_point, element::f32);
auto subtract = std::make_shared<opset10::Subtract>(second_convert, convert_on_zero_point);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
activations = multiply;
}
std::shared_ptr<Node> weights = opset10::Constant::create(element::i8, Shape{4, 16, 1, 1}, {-3});
{
auto clamp = std::make_shared<opset10::Clamp>(weights, -2, 2);
auto convert = std::make_shared<opset10::Convert>(clamp, element::f32);
auto zero_point = opset10::Constant::create(element::i8, Shape{}, {127});
auto convert_on_zero_point = std::make_shared<opset10::Convert>(zero_point, element::f32);
auto subtract = std::make_shared<opset10::Subtract>(convert, convert_on_zero_point);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
weights = multiply;
}
auto conv = std::make_shared<opset10::Convolution>(activations,
weights,
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
function = std::make_shared<Model>(conv, ParameterVector{parameter});
}
manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();
{
auto parameter = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 16, 14, 14});
std::shared_ptr<Node> activations =
std::make_shared<opset10::FakeQuantize>(parameter,
opset10::Constant::create(element::f32, Shape{}, {0}),
opset10::Constant::create(element::f32, Shape{}, {20}),
opset10::Constant::create(element::f32, Shape{}, {0}),
opset10::Constant::create(element::f32, Shape{}, {254}),
255);
{
auto first_convert = std::make_shared<opset10::Convert>(activations, element::u8);
auto second_convert = std::make_shared<opset10::Convert>(first_convert, element::f32);
auto zero_point = opset10::Constant::create(element::u8, Shape{}, {127});
auto convert_on_zero_point = std::make_shared<opset10::Convert>(zero_point, element::f32);
pass::disable_constant_folding(convert_on_zero_point);
auto subtract = std::make_shared<opset10::Subtract>(second_convert, convert_on_zero_point);
mark_as_dequantization_node(subtract);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
mark_as_dequantization_node(multiply);
activations = multiply;
}
std::shared_ptr<Node> weights = opset10::Constant::create(element::i8, Shape{4, 16, 1, 1}, {-2});
{
// Clamp was constant-folded into the weights Constant above ({-3} clamped to -2)
auto convert = std::make_shared<opset10::Convert>(weights, element::f32);
pass::disable_constant_folding(convert);
auto zero_point = opset10::Constant::create(element::i8, Shape{}, {127});
auto convert_on_zero_point = std::make_shared<opset10::Convert>(zero_point, element::f32);
pass::disable_constant_folding(convert_on_zero_point);
auto subtract = std::make_shared<opset10::Subtract>(convert, convert_on_zero_point);
mark_as_dequantization_node(subtract);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
mark_as_dequantization_node(multiply);
weights = multiply;
}
auto conv = std::make_shared<opset10::Convolution>(activations,
weights,
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
function_ref = std::make_shared<Model>(conv, ParameterVector{parameter});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}
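
Beyond the test above, here is a minimal usage sketch of the pass. It is not part of this commit: make_example and run_example are hypothetical names, the internal transformations header path is assumed to match the one used by the test target, and the expected outcome simply restates the behavior described in the test comment. It builds a weight-only dequantization subgraph of the "without Subtract" form, runs MarkDequantizationSubgraph followed by ConstantFolding, and checks that the Convert on the constant path survives because constant folding was disabled on it.

#include <cassert>

#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/mark_dequantization_subgraph.hpp>  // assumed header location

std::shared_ptr<ov::Model> make_example() {
    using namespace ov;
    auto data = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 4});
    auto weights = opset10::Constant::create(element::i8, Shape{4, 4}, {2});
    auto convert = std::make_shared<opset10::Convert>(weights, element::f32);
    auto scale = opset10::Constant::create(element::f32, Shape{}, {0.5});
    auto dequantized = std::make_shared<opset10::Multiply>(convert, scale);
    auto matmul = std::make_shared<opset10::MatMul>(data, dequantized);
    return std::make_shared<Model>(matmul, ParameterVector{data});
}

void run_example() {
    auto model = make_example();
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::i8});
    manager.register_pass<ov::pass::ConstantFolding>();
    manager.run_passes(model);

    // Without the DisableConstantFolding marker, ConstantFolding would fuse
    // Convert and Multiply into a single f32 Constant; with the marker the
    // Convert on the weights path is expected to still be present.
    size_t converts = 0;
    for (const auto& op : model->get_ordered_ops())
        if (ov::is_type<ov::opset10::Convert>(op))
            ++converts;
    assert(converts == 1);
}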


@@ -10,9 +10,26 @@
#include <transformations/rt_info/dequantization_node.hpp>
#include <transformations/rt_info/disable_constant_folding.hpp>
#include "bound_evaluation_util.hpp"
using namespace ngraph;

static bool is_constfoldable(const ov::Output<ov::Node>& output) {
    auto status = true;
    std::deque<ov::Node*> nodes_to_calculate = {output.get_node()};

    while (status && !nodes_to_calculate.empty()) {
        auto current_node = nodes_to_calculate.front();
        nodes_to_calculate.pop_front();

        if (current_node->get_input_size() == 0 && !ov::is_type<ov::op::v0::Constant>(current_node)) {
            // a leaf that is not a Constant - the subgraph cannot be constant folded
            status = false;
        } else {
            // not a leaf - continue the search through the node's inputs
            for (const auto& input_value : current_node->input_values()) {
                const auto& input_node = input_value.get_node();
                nodes_to_calculate.push_front(input_node);
            }
        }
    }
    return status;
}
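
// Not part of this commit: a minimal sketch of what the check above accepts and
// rejects, assuming the opset10 header used elsewhere in this file; the helper
// name is hypothetical. A path whose only leaf is a Constant is reported as
// const-foldable; a path that reaches a Parameter leaf is not.
static void is_constfoldable_sketch() {
    using namespace ov;
    auto weights = opset10::Constant::create(element::i8, Shape{2, 2}, {1});
    auto convert = std::make_shared<opset10::Convert>(weights, element::f32);
    bool on_constants = is_constfoldable(convert);  // true: the only leaf is a Constant

    auto param = std::make_shared<opset10::Parameter>(element::f32, Shape{2, 2});
    auto relu = std::make_shared<opset10::Relu>(param);
    bool on_activations = is_constfoldable(relu);  // false: the walk reaches a Parameter leaf
    (void)on_constants;
    (void)on_activations;
}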
ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::TypeVector& precisions) {
// Dequantization subgraph may have two forms: with and without Subtract
@@ -51,14 +68,13 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
}
}
const auto& input_precision = input->get_output_element_type(0);
// validation by Convert operation input precisions
if (std::find(precisions.begin(), precisions.end(), input_precision) == precisions.end()) {
return false;
}
- std::vector<Node*> tmp;
- if (ov::could_propagate(input, tmp)) {
+ if (is_constfoldable(input)) {
// disable ConstantFolding if dequantization subgraph is on constant data
ov::disable_constant_folding(convert);
}
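
For contrast with the replaced check, here is a speculative sketch (not from this commit) of a graph on which the two predicates can disagree. It assumes the removed ov::could_propagate, whose declaration came from the bound_evaluation_util.hpp include above, keeps its special case of stopping at ShapeOf nodes, whose values are obtainable through bound propagation even though no Constant feeds them; is_constfoldable, by construction, only accepts Constant leaves. The helper name below is hypothetical, and the opset10 include is assumed as above.

static void contrast_with_could_propagate() {
    using namespace ov;
    auto data = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 3, 16, 16});
    auto shape = std::make_shared<opset10::ShapeOf>(data);

    std::vector<Node*> order;
    // could_propagate() is assumed to stop at ShapeOf and report true:
    // the value of 'shape' is derivable without folding any Constant.
    bool propagatable = ov::could_propagate(shape, order);
    // is_constfoldable() keeps walking up to the Parameter 'data' and reports false,
    // so ConstantFolding would not be disabled for a dequantization fed by this path.
    bool foldable = is_constfoldable(shape);
    (void)propagatable;
    (void)foldable;
}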