initial
This commit is contained in:
parent
d8e7b39edb
commit
b1099c1c40
@ -56,14 +56,15 @@ TSUnaryForward::TSUnaryForward() {
|
||||
MATCHER_SCOPE(TSUnaryForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
|
||||
auto unary_label =
|
||||
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
|
||||
{transpose_label});
|
||||
auto fq_label = wrap_type<FakeQuantize>({transpose_label, any_input(), any_input(), any_input(), any_input()});
|
||||
auto unary_op_label =
|
||||
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({transpose_label});
|
||||
auto unary_label = std::make_shared<pattern::op::Or>(OutputVector{fq_label, unary_op_label});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
||||
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});
|
||||
|
||||
const NodePair new_nodes = SwapNodes(transpose, unary);
|
||||
|
||||
@ -74,7 +75,7 @@ TSUnaryForward::TSUnaryForward() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TSUnaryForward");
|
||||
auto m = std::make_shared<Matcher>(unary_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
@ -91,10 +92,12 @@ TSUnaryBackward::TSUnaryBackward() {
|
||||
return HasSameOutputTransposeNodes(output);
|
||||
};
|
||||
|
||||
auto unary_label =
|
||||
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
|
||||
{any_input()},
|
||||
unary_restrictions);
|
||||
auto fq_label =
|
||||
wrap_type<FakeQuantize>({any_input(), any_input(), any_input(), any_input(), any_input()}, unary_restrictions);
|
||||
auto unary_op_label =
|
||||
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({any_input()},
|
||||
unary_restrictions);
|
||||
auto unary_label = std::make_shared<pattern::op::Or>(OutputVector{fq_label, unary_op_label});
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
|
||||
@ -104,9 +107,11 @@ TSUnaryBackward::TSUnaryBackward() {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
||||
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary,
|
||||
transpose_const,
|
||||
/* input_indexes */ {0})) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
unary->validate_and_infer_types();
|
||||
@ -116,6 +121,6 @@ TSUnaryBackward::TSUnaryBackward() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TSUnaryBackward");
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
@ -285,6 +285,7 @@ bool CanPropagateForwardThrough(Node* node) {
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(PRelu, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(FakeQuantize, node);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
@ -85,6 +85,15 @@ NodePtr UnaryFactory<Convert>::create(const OutputVector& inputs) const {
|
||||
return std::make_shared<Convert>(inputs[0], element::f64);
|
||||
}
|
||||
|
||||
template <>
|
||||
NodePtr UnaryFactory<FakeQuantize>::create(const OutputVector& inputs) const {
|
||||
auto input_low = std::make_shared<Constant>(element::f32, Shape{1}, Shape{1});
|
||||
auto input_high = std::make_shared<Constant>(element::f32, Shape{1}, Shape{20});
|
||||
auto output_low = std::make_shared<Constant>(element::f32, Shape{1}, Shape{0});
|
||||
auto output_high = std::make_shared<Constant>(element::f32, Shape{1}, Shape{10});
|
||||
return std::make_shared<FakeQuantize>(inputs[0], input_low, input_high, output_low, output_high, 11);
|
||||
}
|
||||
|
||||
template <typename UnaryT>
|
||||
FactoryPtr CreateUnaryFactory(const std::string& type_name) {
|
||||
return std::make_shared<UnaryFactory<UnaryT>>(type_name);
|
||||
@ -361,7 +370,7 @@ std::vector<FactoryPtr> unary_factories = {
|
||||
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
|
||||
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
|
||||
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
|
||||
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
|
||||
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh), CREATE_UNARY_FACTORY(FakeQuantize)};
|
||||
|
||||
TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
|
||||
FactoryPtr unary_factory;
|
||||
|
Loading…
Reference in New Issue
Block a user