Add ConvertSubtractWithConstant to MOCTransformations (#17058)
* Add ConvertSubtractWithConstant to MOCTransformations Ticket: CVS-62419 * fix test_mo_import_from_memory tests * move test file --------- Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
parent
da4316845f
commit
dfaa4e7bd6
@ -13,6 +13,7 @@ namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertSubtract;
|
||||
class TRANSFORMATIONS_API ConvertSubtractWithConstant;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
@ -22,3 +23,9 @@ public:
|
||||
OPENVINO_RTTI("ConvertSubtract", "0");
|
||||
ConvertSubtract();
|
||||
};
|
||||
|
||||
class ov::pass::ConvertSubtractWithConstant : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ConvertSubtractWithConstant", "0");
|
||||
ConvertSubtractWithConstant();
|
||||
};
|
||||
|
@ -76,6 +76,7 @@
|
||||
#include <transformations/op_conversions/convert_divide.hpp>
|
||||
#include <transformations/op_conversions/convert_negative.hpp>
|
||||
#include <transformations/op_conversions/convert_scatter_elements_to_scatter.hpp>
|
||||
#include <transformations/op_conversions/convert_subtract.hpp>
|
||||
#include <transformations/op_conversions/convert_ti_to_sequences.hpp>
|
||||
#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
|
||||
#include <transformations/smart_reshape/reshape_sinking.hpp>
|
||||
@ -212,6 +213,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
||||
auto decomp = manager.register_pass<ov::pass::GraphRewrite>();
|
||||
ADD_MATCHER(decomp, BatchNormDecomposition)
|
||||
ADD_MATCHER(decomp, ConvertDivideWithConstant)
|
||||
ADD_MATCHER(decomp, ConvertSubtractWithConstant)
|
||||
ADD_MATCHER(decomp, ConvertNegative)
|
||||
|
||||
manager.register_pass<ov::pass::LinOpSequenceFusion>();
|
||||
|
@ -4,46 +4,74 @@
|
||||
|
||||
#include "transformations/op_conversions/convert_subtract.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/core/rt_info.hpp>
|
||||
#include <openvino/core/validation_util.hpp>
|
||||
#include <openvino/opsets/opset1.hpp>
|
||||
#include <vector>
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/rt_info/dequantization_node.hpp"
|
||||
|
||||
ov::pass::ConvertSubtract::ConvertSubtract() {
|
||||
using namespace ov;
|
||||
|
||||
static bool convert_subtract(const std::shared_ptr<Node>& node) {
|
||||
auto sub = std::dynamic_pointer_cast<opset1::Subtract>(node);
|
||||
if (!sub) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ov::is_dequantization_node(sub)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!sub->get_input_element_type(1).is_signed()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> neg =
|
||||
std::make_shared<opset1::Multiply>(sub->input_value(1),
|
||||
opset1::Constant::create(sub->get_input_element_type(1), Shape{}, {-1}));
|
||||
NodeVector new_nodes;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
if (auto constant = ov::get_constant_from_source(neg)) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
neg = constant;
|
||||
} else {
|
||||
new_nodes.push_back(neg);
|
||||
}
|
||||
|
||||
auto add = std::make_shared<opset1::Add>(sub->input_value(0), neg);
|
||||
new_nodes.push_back(add);
|
||||
|
||||
add->set_friendly_name(sub->get_friendly_name());
|
||||
copy_runtime_info(sub, new_nodes);
|
||||
replace_node(sub, add);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
pass::ConvertSubtract::ConvertSubtract() {
|
||||
MATCHER_SCOPE(ConvertSubtract);
|
||||
auto sub = ngraph::pattern::wrap_type<ov::opset1::Subtract>();
|
||||
auto sub = pattern::wrap_type<opset1::Subtract>();
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto sub = std::dynamic_pointer_cast<ov::opset1::Subtract>(m.get_match_root());
|
||||
if (!sub) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ov::is_dequantization_node(sub)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!sub->get_input_element_type(1).is_signed()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto neg = std::make_shared<ov::opset1::Multiply>(
|
||||
sub->input(1).get_source_output(),
|
||||
opset1::Constant::create(sub->get_input_element_type(1), Shape{}, {-1}));
|
||||
|
||||
auto add = std::make_shared<ov::opset1::Add>(sub->input(0).get_source_output(), neg);
|
||||
|
||||
add->set_friendly_name(sub->get_friendly_name());
|
||||
ngraph::copy_runtime_info(sub, {neg, add});
|
||||
ngraph::replace_node(sub, add);
|
||||
|
||||
return true;
|
||||
auto node = m.get_match_root();
|
||||
return convert_subtract(node);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(sub, matcher_name);
|
||||
auto m = std::make_shared<pattern::Matcher>(sub, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
pass::ConvertSubtractWithConstant::ConvertSubtractWithConstant() {
|
||||
MATCHER_SCOPE(ConvertSubtractWithConstant);
|
||||
auto sub = pattern::wrap_type<opset1::Subtract>({pattern::any_input(), pattern::wrap_type<opset1::Constant>()});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto node = m.get_match_root();
|
||||
return convert_subtract(node);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(sub, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
@ -0,0 +1,91 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <transformations/low_precision/mark_dequantization_subgraph.hpp>
|
||||
#include <transformations/op_conversions/convert_subtract.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ov;
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertSubtract) {
|
||||
{
|
||||
auto data1 = std::make_shared<opset8::Parameter>(element::f32, Shape{3, 1, 2});
|
||||
auto constant = opset8::Constant::create(element::f32, Shape{1}, {1.5});
|
||||
auto sub1 = std::make_shared<opset8::Subtract>(data1, constant);
|
||||
auto data2 = std::make_shared<opset8::Parameter>(element::f32, Shape{2});
|
||||
auto sub2 = std::make_shared<opset8::Subtract>(data1, data2);
|
||||
|
||||
function = std::make_shared<Model>(NodeVector{sub1, sub2}, ParameterVector{data1, data2});
|
||||
|
||||
manager.register_pass<pass::ConvertSubtract>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data1 = std::make_shared<opset8::Parameter>(element::f32, Shape{3, 1, 2});
|
||||
auto constant = opset8::Constant::create(element::f32, Shape{1}, {-1.5});
|
||||
auto add1 = std::make_shared<opset8::Add>(data1, constant);
|
||||
auto data2 = std::make_shared<opset8::Parameter>(element::f32, Shape{2});
|
||||
auto neg = std::make_shared<opset8::Multiply>(data2, opset8::Constant::create(element::f32, Shape{}, {-1}));
|
||||
auto add2 = std::make_shared<opset8::Add>(data1, neg);
|
||||
|
||||
function_ref = std::make_shared<Model>(NodeVector{add1, add2}, ParameterVector{data1, data2});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertSubtractWithConstant) {
|
||||
{
|
||||
auto data1 = std::make_shared<opset8::Parameter>(element::f32, Shape{3, 1, 2});
|
||||
auto constant = opset8::Constant::create(element::f32, Shape{1}, {1.5});
|
||||
auto sub1 = std::make_shared<opset8::Subtract>(data1, constant);
|
||||
auto data2 = std::make_shared<opset8::Parameter>(element::f32, Shape{2});
|
||||
auto sub2 = std::make_shared<opset8::Subtract>(data1, data2);
|
||||
|
||||
function = std::make_shared<Model>(NodeVector{sub1, sub2}, ParameterVector{data1, data2});
|
||||
|
||||
manager.register_pass<pass::ConvertSubtractWithConstant>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data1 = std::make_shared<opset8::Parameter>(element::f32, Shape{3, 1, 2});
|
||||
auto constant = opset8::Constant::create(element::f32, Shape{1}, {-1.5});
|
||||
auto add = std::make_shared<opset8::Add>(data1, constant);
|
||||
auto data2 = std::make_shared<opset8::Parameter>(element::f32, Shape{2});
|
||||
auto sub = std::make_shared<opset8::Subtract>(data1, data2);
|
||||
|
||||
function_ref = std::make_shared<Model>(NodeVector{add, sub}, ParameterVector{data1, data2});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertSubtractDequantizationSubgraph) {
|
||||
auto data = std::make_shared<opset8::Parameter>(element::u8, Shape{1, 3, 2, 2});
|
||||
auto convert = std::make_shared<opset8::Convert>(data, element::f32);
|
||||
auto zero_point = opset8::Constant::create(element::f32, Shape{1}, {2});
|
||||
auto sub = std::make_shared<opset8::Subtract>(convert, zero_point);
|
||||
auto scale = opset8::Constant::create(element::f32, Shape{1}, {3});
|
||||
auto mul = std::make_shared<opset8::Multiply>(sub, scale);
|
||||
|
||||
function = std::make_shared<Model>(mul, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8});
|
||||
manager.register_pass<pass::ConvertSubtract>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertSubtractUnsignedType) {
|
||||
auto data = std::make_shared<opset8::Parameter>(element::u32, Shape{1, 3, 2, 2});
|
||||
auto constant = opset8::Constant::create(element::u32, Shape{1}, {2});
|
||||
auto sub = std::make_shared<opset8::Subtract>(data, constant);
|
||||
|
||||
function = std::make_shared<Model>(sub, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::ConvertSubtract>();
|
||||
}
|
@ -442,12 +442,12 @@ def create_pytorch_nn_module_mean_list(tmp_dir):
|
||||
shape = PartialShape(shape)
|
||||
param1 = ov.opset8.parameter(shape)
|
||||
param2 = ov.opset8.parameter(shape)
|
||||
const1 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float32)
|
||||
const2 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float32)
|
||||
sub1 = ov.opset8.subtract(param1, const1)
|
||||
sub2 = ov.opset8.subtract(param2, const2)
|
||||
add = ov.opset8.add(sub1, sub2)
|
||||
relu = ov.opset8.relu(add)
|
||||
const1 = ov.opset8.constant([[[[-0.0, -0.0, -0.0]]]], dtype=np.float32)
|
||||
const2 = ov.opset8.constant([[[[-0.0, -0.0, -0.0]]]], dtype=np.float32)
|
||||
add1 = ov.opset8.add(param1, const1)
|
||||
add2 = ov.opset8.add(param2, const2)
|
||||
add3 = ov.opset8.add(add1, add2)
|
||||
relu = ov.opset8.relu(add3)
|
||||
sigm = ov.opset8.sigmoid(relu)
|
||||
|
||||
parameter_list = [param1, param2]
|
||||
@ -466,12 +466,12 @@ def create_pytorch_nn_module_mean_list_default_no_compression(tmp_dir):
|
||||
shape = PartialShape(shape)
|
||||
param1 = ov.opset8.parameter(shape)
|
||||
param2 = ov.opset8.parameter(shape)
|
||||
const1 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float32)
|
||||
const2 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float32)
|
||||
sub1 = ov.opset8.subtract(param1, const1)
|
||||
sub2 = ov.opset8.subtract(param2, const2)
|
||||
add = ov.opset8.add(sub1, sub2)
|
||||
relu = ov.opset8.relu(add)
|
||||
const1 = ov.opset8.constant([[[[-0.0, -0.0, -0.0]]]], dtype=np.float32)
|
||||
const2 = ov.opset8.constant([[[[-0.0, -0.0, -0.0]]]], dtype=np.float32)
|
||||
add1 = ov.opset8.add(param1, const1)
|
||||
add2 = ov.opset8.add(param2, const2)
|
||||
add3 = ov.opset8.add(add1, add2)
|
||||
relu = ov.opset8.relu(add3)
|
||||
sigm = ov.opset8.sigmoid(relu)
|
||||
|
||||
parameter_list = [param1, param2]
|
||||
@ -487,12 +487,12 @@ def create_pytorch_nn_module_mean_list_compressin_enabled(tmp_dir):
|
||||
shape = PartialShape(shape)
|
||||
param1 = ov.opset8.parameter(shape)
|
||||
param2 = ov.opset8.parameter(shape)
|
||||
const1 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float32)
|
||||
const2 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float32)
|
||||
sub1 = ov.opset8.subtract(param1, const1)
|
||||
sub2 = ov.opset8.subtract(param2, const2)
|
||||
add = ov.opset8.add(sub1, sub2)
|
||||
relu = ov.opset8.relu(add)
|
||||
const1 = ov.opset8.constant([[[[-0.0, -0.0, -0.0]]]], dtype=np.float32)
|
||||
const2 = ov.opset8.constant([[[[-0.0, -0.0, -0.0]]]], dtype=np.float32)
|
||||
add1 = ov.opset8.add(param1, const1)
|
||||
add2 = ov.opset8.add(param2, const2)
|
||||
add3 = ov.opset8.add(add1, add2)
|
||||
relu = ov.opset8.relu(add3)
|
||||
sigm = ov.opset8.sigmoid(relu)
|
||||
|
||||
parameter_list = [param1, param2]
|
||||
|
Loading…
Reference in New Issue
Block a user