This commit is contained in:
Evgeny Kotov 2023-03-02 14:38:43 +01:00
parent d8e7b39edb
commit b1099c1c40
3 changed files with 28 additions and 13 deletions

View File

@ -56,14 +56,15 @@ TSUnaryForward::TSUnaryForward() {
MATCHER_SCOPE(TSUnaryForward);
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
{transpose_label});
auto fq_label = wrap_type<FakeQuantize>({transpose_label, any_input(), any_input(), any_input(), any_input()});
auto unary_op_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({transpose_label});
auto unary_label = std::make_shared<pattern::op::Or>(OutputVector{fq_label, unary_op_label});
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});
const NodePair new_nodes = SwapNodes(transpose, unary);
@ -74,7 +75,7 @@ TSUnaryForward::TSUnaryForward() {
return true;
};
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TSUnaryForward");
auto m = std::make_shared<Matcher>(unary_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
@ -91,10 +92,12 @@ TSUnaryBackward::TSUnaryBackward() {
return HasSameOutputTransposeNodes(output);
};
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
{any_input()},
unary_restrictions);
auto fq_label =
wrap_type<FakeQuantize>({any_input(), any_input(), any_input(), any_input(), any_input()}, unary_restrictions);
auto unary_op_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({any_input()},
unary_restrictions);
auto unary_label = std::make_shared<pattern::op::Or>(OutputVector{fq_label, unary_op_label});
auto transpose_const_label = wrap_type<Constant>();
@ -104,9 +107,11 @@ TSUnaryBackward::TSUnaryBackward() {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary,
transpose_const,
/* input_indexes */ {0})) {
register_new_node(new_node);
}
unary->validate_and_infer_types();
@ -116,6 +121,6 @@ TSUnaryBackward::TSUnaryBackward() {
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TSUnaryBackward");
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -285,6 +285,7 @@ bool CanPropagateForwardThrough(Node* node) {
CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(PRelu, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(FakeQuantize, node);
return false;
}

View File

@ -85,6 +85,15 @@ NodePtr UnaryFactory<Convert>::create(const OutputVector& inputs) const {
return std::make_shared<Convert>(inputs[0], element::f64);
}
template <>
NodePtr UnaryFactory<FakeQuantize>::create(const OutputVector& inputs) const {
auto input_low = std::make_shared<Constant>(element::f32, Shape{1}, Shape{1});
auto input_high = std::make_shared<Constant>(element::f32, Shape{1}, Shape{20});
auto output_low = std::make_shared<Constant>(element::f32, Shape{1}, Shape{0});
auto output_high = std::make_shared<Constant>(element::f32, Shape{1}, Shape{10});
return std::make_shared<FakeQuantize>(inputs[0], input_low, input_high, output_low, output_high, 11);
}
template <typename UnaryT>
FactoryPtr CreateUnaryFactory(const std::string& type_name) {
return std::make_shared<UnaryFactory<UnaryT>>(type_name);
@ -361,7 +370,7 @@ std::vector<FactoryPtr> unary_factories = {
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh), CREATE_UNARY_FACTORY(FakeQuantize)};
TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
FactoryPtr unary_factory;