Don't use could_propagate in MarkDequantizationSubgraph (#15325)
This commit is contained in:
parent
b9a1b45a82
commit
9a540e61dc
@ -346,3 +346,126 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstantWeights) {
    // Input graph:
    //
    //     Parameter
    //      |F32
    //      |
    //     FakeQuantize           Constant
    //      |F32                   |I8
    //      |                      |
    //     Convert  Constant      Clamp    Constant
    //      |U8      |U8           |I8      |I8
    //      |        |             |        |
    //     Convert  Convert(DCF)  Convert(DCF)  Convert(DCF)
    //       \FP32   /FP32         |FP32       /F32
    //        \     /              |          /
    //        Subtract  Constant  Subtract   Constant
    //          \FP32   /FP32      |FP32     /FP32
    //           \     /           |        /
    //           Multiply         Multiply
    //              \FP32         /FP32
    //               \           /
    //               Convolution
    //
    // After MarkDequantizationSubgraph every Subtract and Multiply node above carries the
    // 'DequantizationNode' attribute, and every 'Convert(DCF)' node carries the
    // 'DisableConstantFolding' attribute. The Clamp on the weights path sits on constant
    // data, so ConstantFolding collapses it in the reference model.

    {
        auto input_param = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 16, 14, 14});
        std::shared_ptr<Node> activations =
            std::make_shared<opset10::FakeQuantize>(input_param,
                                                    opset10::Constant::create(element::f32, Shape{}, {0}),
                                                    opset10::Constant::create(element::f32, Shape{}, {20}),
                                                    opset10::Constant::create(element::f32, Shape{}, {0}),
                                                    opset10::Constant::create(element::f32, Shape{}, {254}),
                                                    255);
        {
            // Quantize activations to u8, then build the dequantization chain on top.
            auto quantize = std::make_shared<opset10::Convert>(activations, element::u8);
            auto dequantize = std::make_shared<opset10::Convert>(quantize, element::f32);
            auto zero_point = opset10::Constant::create(element::u8, Shape{}, {127});
            auto zp_convert = std::make_shared<opset10::Convert>(zero_point, element::f32);
            auto shifted = std::make_shared<opset10::Subtract>(dequantize, zp_convert);
            auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
            activations = std::make_shared<opset10::Multiply>(shifted, scale);
        }

        // Quantized i8 weights go through a Clamp before the dequantization chain.
        std::shared_ptr<Node> weights = opset10::Constant::create(element::i8, Shape{4, 16, 1, 1}, {-3});
        {
            auto clamped = std::make_shared<opset10::Clamp>(weights, -2, 2);
            auto w_convert = std::make_shared<opset10::Convert>(clamped, element::f32);
            auto zero_point = opset10::Constant::create(element::i8, Shape{}, {127});
            auto zp_convert = std::make_shared<opset10::Convert>(zero_point, element::f32);
            auto shifted = std::make_shared<opset10::Subtract>(w_convert, zp_convert);
            auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
            weights = std::make_shared<opset10::Multiply>(shifted, scale);
        }

        auto conv = std::make_shared<opset10::Convolution>(activations,
                                                           weights,
                                                           Strides{1, 1},
                                                           CoordinateDiff{0, 0},
                                                           CoordinateDiff{0, 0},
                                                           Strides{1, 1});
        function = std::make_shared<Model>(conv, ParameterVector{input_param});
    }

    manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8, element::i8});
    manager.register_pass<pass::ConstantFolding>();

    {
        auto input_param = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 16, 14, 14});
        std::shared_ptr<Node> activations =
            std::make_shared<opset10::FakeQuantize>(input_param,
                                                    opset10::Constant::create(element::f32, Shape{}, {0}),
                                                    opset10::Constant::create(element::f32, Shape{}, {20}),
                                                    opset10::Constant::create(element::f32, Shape{}, {0}),
                                                    opset10::Constant::create(element::f32, Shape{}, {254}),
                                                    255);
        {
            auto quantize = std::make_shared<opset10::Convert>(activations, element::u8);
            auto dequantize = std::make_shared<opset10::Convert>(quantize, element::f32);
            auto zero_point = opset10::Constant::create(element::u8, Shape{}, {127});
            auto zp_convert = std::make_shared<opset10::Convert>(zero_point, element::f32);
            pass::disable_constant_folding(zp_convert);
            auto shifted = std::make_shared<opset10::Subtract>(dequantize, zp_convert);
            mark_as_dequantization_node(shifted);
            auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
            auto scaled = std::make_shared<opset10::Multiply>(shifted, scale);
            mark_as_dequantization_node(scaled);
            activations = scaled;
        }

        // Clamp was constant-folded: {-3} clamped to [-2, 2] yields {-2}.
        std::shared_ptr<Node> weights = opset10::Constant::create(element::i8, Shape{4, 16, 1, 1}, {-2});
        {
            auto w_convert = std::make_shared<opset10::Convert>(weights, element::f32);
            pass::disable_constant_folding(w_convert);
            auto zero_point = opset10::Constant::create(element::i8, Shape{}, {127});
            auto zp_convert = std::make_shared<opset10::Convert>(zero_point, element::f32);
            pass::disable_constant_folding(zp_convert);
            auto shifted = std::make_shared<opset10::Subtract>(w_convert, zp_convert);
            mark_as_dequantization_node(shifted);
            auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
            auto scaled = std::make_shared<opset10::Multiply>(shifted, scale);
            mark_as_dequantization_node(scaled);
            weights = scaled;
        }

        auto conv = std::make_shared<opset10::Convolution>(activations,
                                                           weights,
                                                           Strides{1, 1},
                                                           CoordinateDiff{0, 0},
                                                           CoordinateDiff{0, 0},
                                                           Strides{1, 1});
        function_ref = std::make_shared<Model>(conv, ParameterVector{input_param});
    }

    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}
|
||||
|
@ -10,9 +10,26 @@
|
||||
#include <deque>
#include <unordered_set>

#include <transformations/rt_info/dequantization_node.hpp>
#include <transformations/rt_info/disable_constant_folding.hpp>

#include "bound_evaluation_util.hpp"
// Returns true if the subgraph feeding `output` can be constant-folded, i.e.
// every source node reachable upstream of `output` is a Constant. Any other
// leaf (e.g. a Parameter) makes the subgraph non-foldable.
//
// Fixes vs. the previous version:
//  - dropped the unused (and deprecated) `using namespace ngraph;`
//  - a `visited` set prevents re-traversing shared nodes; without it a
//    diamond-shaped constant sub-DAG could be walked an exponential number of times
//  - the stale "not a shape_of" comment was removed: no ShapeOf check exists here
static bool is_constfoldable(const ov::Output<ov::Node>& output) {
    std::deque<ov::Node*> nodes_to_check = {output.get_node()};
    std::unordered_set<ov::Node*> visited;

    while (!nodes_to_check.empty()) {
        auto* node = nodes_to_check.front();
        nodes_to_check.pop_front();
        if (!visited.insert(node).second)
            continue;  // already examined through another path

        if (node->get_input_size() == 0) {
            // A leaf must be a Constant for the subgraph to be foldable.
            if (!ov::is_type<ov::op::v0::Constant>(node))
                return false;
        } else {
            // Interior node: keep searching upstream.
            for (const auto& input_value : node->input_values())
                nodes_to_check.push_front(input_value.get_node());
        }
    }
    return true;
}
|
||||
|
||||
ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::TypeVector& precisions) {
|
||||
// Dequantization subgraph may have two forms: with and without Subtract
|
||||
@ -51,14 +68,13 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
|
||||
}
|
||||
}
|
||||
|
||||
// validation by Convert operation input precisions
|
||||
const auto& input_precision = input->get_output_element_type(0);
|
||||
// validation by Convert operation input precisions
|
||||
if (std::find(precisions.begin(), precisions.end(), input_precision) == precisions.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<Node*> tmp;
|
||||
if (ov::could_propagate(input, tmp)) {
|
||||
if (is_constfoldable(input)) {
|
||||
// disable ConstantFolding if dequantization subgraph is on constant data
|
||||
ov::disable_constant_folding(convert);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user