Add FakeQuantize op support in TS transformations (#17243)
* Add FQ (FakeQuantize) op support in TS (transpose sinking) transformations
* Fix code style
* Mark FQ as a supported op in the TS ops list
This commit is contained in:
parent
22bb3af7df
commit
40bf400b18
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "openvino/op/constant.hpp"
|
#include "openvino/op/constant.hpp"
|
||||||
|
#include "openvino/op/fake_quantize.hpp"
|
||||||
#include "openvino/op/prelu.hpp"
|
#include "openvino/op/prelu.hpp"
|
||||||
#include "openvino/op/transpose.hpp"
|
#include "openvino/op/transpose.hpp"
|
||||||
#include "openvino/op/util/op_types.hpp"
|
#include "openvino/op/util/op_types.hpp"
|
||||||
@ -25,7 +26,8 @@ TSBinaryForward::TSBinaryForward() {
|
|||||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||||
op::util::BinaryElementwiseComparison,
|
op::util::BinaryElementwiseComparison,
|
||||||
op::util::BinaryElementwiseLogical,
|
op::util::BinaryElementwiseLogical,
|
||||||
ov::op::v0::PRelu>([](const Output<Node>& output) -> bool {
|
ov::op::v0::PRelu,
|
||||||
|
ov::op::v0::FakeQuantize>([](const Output<Node>& output) -> bool {
|
||||||
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
|
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -62,7 +64,8 @@ TSBinaryBackward::TSBinaryBackward() {
|
|||||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||||
op::util::BinaryElementwiseComparison,
|
op::util::BinaryElementwiseComparison,
|
||||||
op::util::BinaryElementwiseLogical,
|
op::util::BinaryElementwiseLogical,
|
||||||
ov::op::v0::PRelu>([](const Output<Node>& output) -> bool {
|
ov::op::v0::PRelu,
|
||||||
|
ov::op::v0::FakeQuantize>([](const Output<Node>& output) -> bool {
|
||||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -346,6 +346,7 @@ bool CanPropagateForwardThrough(Node* node) {
|
|||||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Reshape, node)
|
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Reshape, node)
|
||||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Unsqueeze, node)
|
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Unsqueeze, node)
|
||||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Transpose, node)
|
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Transpose, node)
|
||||||
|
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::FakeQuantize, node)
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -217,6 +217,22 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) {
|
|||||||
return std::make_shared<ReshapeFactory>(type_name);
|
return std::make_shared<ReshapeFactory>(type_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class FakeQuantizeFactory : public IFactory {
|
||||||
|
public:
|
||||||
|
explicit FakeQuantizeFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||||
|
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||||
|
return std::make_shared<FakeQuantize>(parent_nodes[0],
|
||||||
|
parent_nodes[1],
|
||||||
|
parent_nodes[2],
|
||||||
|
parent_nodes[3],
|
||||||
|
parent_nodes[4],
|
||||||
|
128);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Returns a FakeQuantizeFactory constructed with `type_name`, as a FactoryPtr.
FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) {
    FactoryPtr fq_factory = std::make_shared<FakeQuantizeFactory>(type_name);
    return fq_factory;
}
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
#undef CREATE_UNARY_FACTORY
|
#undef CREATE_UNARY_FACTORY
|
||||||
@ -255,6 +271,9 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) {
|
|||||||
#undef CREATE_RESHAPE_FACTORY
|
#undef CREATE_RESHAPE_FACTORY
|
||||||
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)
|
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)
|
||||||
|
|
||||||
|
#undef CREATE_FQ_FACTORY
|
||||||
|
#define CREATE_FQ_FACTORY(type_name) common::CreateFakeQuantizeFactory(#type_name)
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
vector<FactoryPtr> unary_factories = {
|
vector<FactoryPtr> unary_factories = {
|
||||||
@ -393,6 +412,42 @@ auto test_forward_binary = []() {
|
|||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TSTestFixture, test_forward_binary());
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TSTestFixture, test_forward_binary());
|
||||||
|
|
||||||
|
auto test_forward_fq = []() {
|
||||||
|
TestCase test_case;
|
||||||
|
|
||||||
|
// Initialize common attributes
|
||||||
|
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryForward);
|
||||||
|
test_case.num_main_ops = {1, 10};
|
||||||
|
test_case.inputs_to_main = {
|
||||||
|
parameter(element::f32, {1, 96, 55, 55}),
|
||||||
|
parameter(element::f32, {55, 55, 96, 1}),
|
||||||
|
parameter(element::f32, {1}),
|
||||||
|
parameter(element::f32, {55, 1, 1, 1}),
|
||||||
|
parameter(element::f32, {55, 55, 1, 1}),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test model description:
|
||||||
|
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||||
|
test_case.model.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
|
||||||
|
test_case.model.model_template = create_model;
|
||||||
|
|
||||||
|
// Reference model description:
|
||||||
|
auto set_unsqueeze_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||||
|
OutputVector new_out_vec = out_vec;
|
||||||
|
auto indices = make_shared<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 1, 2});
|
||||||
|
new_out_vec[2] = make_shared<Unsqueeze>(out_vec[2], indices);
|
||||||
|
return new_out_vec;
|
||||||
|
};
|
||||||
|
test_case.model_ref.preprocess_inputs_to_main = {{set_unsqueeze_for, set_transpose_for}, {{2}, {1, 2, 3, 4}}};
|
||||||
|
test_case.model_ref.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
|
||||||
|
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||||
|
test_case.model_ref.model_template = create_model;
|
||||||
|
|
||||||
|
return wrapper(test_case);
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonFQForward, TSTestFixture, test_forward_fq());
|
||||||
|
|
||||||
auto test_forward_concat = []() {
|
auto test_forward_concat = []() {
|
||||||
TestCase test_case;
|
TestCase test_case;
|
||||||
|
|
||||||
@ -867,6 +922,42 @@ auto test_backward_binary = []() {
|
|||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TSTestFixture, test_backward_binary());
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TSTestFixture, test_backward_binary());
|
||||||
|
|
||||||
|
auto test_backward_fq = []() {
|
||||||
|
TestCase test_case;
|
||||||
|
|
||||||
|
// Initialize common attributes
|
||||||
|
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryBackward);
|
||||||
|
test_case.num_main_ops = {1, 10};
|
||||||
|
test_case.inputs_to_main = {
|
||||||
|
parameter(element::f32, {1, 96, 55, 55}),
|
||||||
|
parameter(element::f32, {1, 96, 55, 55}),
|
||||||
|
parameter(element::f32, {1}),
|
||||||
|
parameter(element::f32, {1, 96, 55, 1}),
|
||||||
|
parameter(element::f32, {1, 96, 1, 1}),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test model description:
|
||||||
|
test_case.model.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
|
||||||
|
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||||
|
test_case.model.model_template = create_model;
|
||||||
|
|
||||||
|
auto set_unsqueeze_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||||
|
OutputVector new_out_vec = out_vec;
|
||||||
|
auto indices = make_shared<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 1, 2});
|
||||||
|
new_out_vec[2] = make_shared<Unsqueeze>(out_vec[2], indices);
|
||||||
|
return new_out_vec;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reference model description:
|
||||||
|
test_case.model_ref.preprocess_inputs_to_main = {{set_unsqueeze_for, set_transpose_for}, {{2}, {0, 1, 2, 3, 4}}};
|
||||||
|
test_case.model_ref.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
|
||||||
|
test_case.model_ref.model_template = create_model;
|
||||||
|
|
||||||
|
return wrapper(test_case);
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonFQBackward, TSTestFixture, test_backward_fq());
|
||||||
|
|
||||||
auto test_backward_concat = []() {
|
auto test_backward_concat = []() {
|
||||||
TestCase test_case;
|
TestCase test_case;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user