Fix Add/MulFQFusion transformations (#10252)
This commit is contained in:
@@ -167,7 +167,7 @@ TRANSFORMATIONS_API bool constantIsEqualTo(const std::shared_ptr<ngraph::op::Con
|
||||
|
||||
TRANSFORMATIONS_API bool has_f16_constants(const std::shared_ptr<const ngraph::Function> &function);
|
||||
|
||||
TRANSFORMATIONS_API bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &other_shape);
|
||||
TRANSFORMATIONS_API bool check_for_broadcast(const ngraph::PartialShape &ref_shape, const ngraph::PartialShape &other_shape);
|
||||
|
||||
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> activation(const std::string& activation_name,
|
||||
const ngraph::Output<ngraph::Node>& apply_to);
|
||||
|
||||
@@ -41,8 +41,16 @@ ngraph::pass::AddFakeQuantizeFusion::AddFakeQuantizeFusion() {
|
||||
auto add_const = std::dynamic_pointer_cast<opset5::Constant>(pattern_value_map.at(const_pattern).get_node_shared_ptr());
|
||||
if (!add_const)
|
||||
return false;
|
||||
std::shared_ptr<Node> new_const = add_const;
|
||||
|
||||
auto const_shape = add_const->get_shape();
|
||||
if (ngraph::op::util::check_for_broadcast(input.get_partial_shape(), const_shape)) {
|
||||
// We can't eliminate Add if Constant input broadcasts another input shape because
|
||||
// when we reconnect input from Add to FQ won't broadcast given input, so it will result
|
||||
// in shape collision.
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> new_const = add_const;
|
||||
size_t const_shape_size = shape_size(const_shape);
|
||||
bool is_single_value = const_shape_size == 1;
|
||||
|
||||
|
||||
@@ -41,12 +41,19 @@ ngraph::pass::MulFakeQuantizeFusion::MulFakeQuantizeFusion() {
|
||||
if (!mul_const)
|
||||
return false;
|
||||
|
||||
auto const_shape = mul_const->get_shape();
|
||||
if (ngraph::op::util::check_for_broadcast(input.get_partial_shape(), const_shape)) {
|
||||
// We can't eliminate Multiply if Constant input broadcasts another input shape because
|
||||
// when we reconnect input from Multiply to FQ won't broadcast given input, so it will result
|
||||
// in shape collision.
|
||||
return false;
|
||||
}
|
||||
|
||||
auto mul_const_value = mul_const->cast_vector<float>();
|
||||
if (std::any_of(mul_const_value.begin(), mul_const_value.end(), [] (float f) -> bool { return f <= 0.0f; }))
|
||||
return false;
|
||||
|
||||
std::shared_ptr<Node> new_const = mul_const;
|
||||
auto const_shape = mul_const->get_shape();
|
||||
size_t const_shape_size = shape_size(const_shape);
|
||||
bool is_single_value = const_shape_size == 1;
|
||||
|
||||
|
||||
@@ -199,8 +199,7 @@ bool ngraph::pass::GroupedStridedSliceOptimizer::run_on_model(const std::shared_
|
||||
if (!valid_for_replacement) break;
|
||||
}
|
||||
|
||||
if (!valid_for_replacement) continue;
|
||||
if (output_to_partition.size() < 2) continue;
|
||||
if (!valid_for_replacement || output_to_partition.size() < 2 || axis == -1) continue;
|
||||
|
||||
std::sort(output_to_partition.begin(), output_to_partition.end(),
|
||||
[](OutputToPatrition lhs, OutputToPatrition rhs)
|
||||
|
||||
@@ -90,18 +90,19 @@ bool has_f16_constants(const std::shared_ptr<const ngraph::Function> &function)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &other_shape) {
|
||||
bool check_for_broadcast(const ngraph::PartialShape &ref_shape, const ngraph::PartialShape &other_shape) {
|
||||
// Check that other_shape doesn't broadcast ref_shape
|
||||
if (other_shape.size() > ref_shape.size()) {
|
||||
if (ref_shape.rank().is_dynamic() || other_shape.rank().is_dynamic() || other_shape.size() > ref_shape.size()) {
|
||||
return true;
|
||||
}
|
||||
auto ref_it = ref_shape.rbegin();
|
||||
auto other_it = other_shape.rbegin();
|
||||
// Check that other_shape dims are equal to ref_shape dims
|
||||
// In case if other_shape rank is less than ref_shape rank
|
||||
// we stop comparision and return true
|
||||
// we stop comparison and return true
|
||||
while (other_it != other_shape.rend()) {
|
||||
if (*other_it != *ref_it && *other_it != 1) {
|
||||
if ((other_it->is_dynamic() || other_it->get_length() != 1) &&
|
||||
(ref_it->is_dynamic() || ref_it->get_length() == 1)) {
|
||||
return true;
|
||||
}
|
||||
++other_it;
|
||||
|
||||
@@ -264,3 +264,16 @@ TEST_F(TransformationTestsF, NegativeAddFakeQuantizeFusionWithNonPerChannelConst
|
||||
function = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
|
||||
manager.register_pass<pass::AddFakeQuantizeFusion>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, AddFakeQuantizeFusionWithBroadcastingConstant) {
|
||||
auto data = std::make_shared<opset5::Parameter>(element::f32, ov::PartialShape{DYN, 3});
|
||||
auto add_const = opset5::Constant::create(element::f32, Shape{3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
auto add = std::make_shared<opset5::Add>(data, add_const);
|
||||
auto input_low = opset5::Constant::create(element::f32, Shape{1}, {0});
|
||||
auto input_high = opset5::Constant::create(element::f32, Shape{1}, {20});
|
||||
auto output_low = opset5::Constant::create(element::f32, Shape{}, {0});
|
||||
auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
|
||||
auto fq = std::make_shared<opset5::FakeQuantize>(add, input_low, input_high, output_low, output_high, 11);
|
||||
function = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
|
||||
manager.register_pass<pass::AddFakeQuantizeFusion>();
|
||||
}
|
||||
@@ -244,3 +244,16 @@ TEST_F(TransformationTestsF, NegativeMulFakeQuantizeFusionWithNonPerChannelConst
|
||||
function = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
|
||||
manager.register_pass<pass::MulFakeQuantizeFusion>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, MulFakeQuantizeFusionWithBroadcastingConstant) {
|
||||
auto data = std::make_shared<opset5::Parameter>(element::f32, ov::PartialShape{DYN, 3});
|
||||
auto mul_const = opset5::Constant::create(element::f32, Shape{3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
auto mul = std::make_shared<opset5::Multiply>(data, mul_const);
|
||||
auto input_low = opset5::Constant::create(element::f32, Shape{1}, {1});
|
||||
auto input_high = opset5::Constant::create(element::f32, Shape{1}, {20});
|
||||
auto output_low = opset5::Constant::create(element::f32, Shape{}, {0});
|
||||
auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
|
||||
auto fq = std::make_shared<opset5::FakeQuantize>(mul, input_low, input_high, output_low, output_high, 11);
|
||||
function = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
|
||||
manager.register_pass<pass::MulFakeQuantizeFusion>();
|
||||
}
|
||||
Reference in New Issue
Block a user