Add FakeQuantize op support in TS transformations (#17243)
* Add FQ op support in TS transformations * codestyle * Mark FQ as supported op in the TS ops list
This commit is contained in:
parent
22bb3af7df
commit
40bf400b18
@ -6,6 +6,7 @@
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/fake_quantize.hpp"
|
||||
#include "openvino/op/prelu.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
@ -25,7 +26,8 @@ TSBinaryForward::TSBinaryForward() {
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||
op::util::BinaryElementwiseComparison,
|
||||
op::util::BinaryElementwiseLogical,
|
||||
ov::op::v0::PRelu>([](const Output<Node>& output) -> bool {
|
||||
ov::op::v0::PRelu,
|
||||
ov::op::v0::FakeQuantize>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
|
||||
});
|
||||
|
||||
@ -62,7 +64,8 @@ TSBinaryBackward::TSBinaryBackward() {
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||
op::util::BinaryElementwiseComparison,
|
||||
op::util::BinaryElementwiseLogical,
|
||||
ov::op::v0::PRelu>([](const Output<Node>& output) -> bool {
|
||||
ov::op::v0::PRelu,
|
||||
ov::op::v0::FakeQuantize>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
});
|
||||
|
||||
|
@ -346,6 +346,7 @@ bool CanPropagateForwardThrough(Node* node) {
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Reshape, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Unsqueeze, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Transpose, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::FakeQuantize, node)
|
||||
|
||||
return false;
|
||||
}
|
||||
|
@ -217,6 +217,22 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) {
|
||||
return std::make_shared<ReshapeFactory>(type_name);
|
||||
}
|
||||
|
||||
class FakeQuantizeFactory : public IFactory {
|
||||
public:
|
||||
explicit FakeQuantizeFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<FakeQuantize>(parent_nodes[0],
|
||||
parent_nodes[1],
|
||||
parent_nodes[2],
|
||||
parent_nodes[3],
|
||||
parent_nodes[4],
|
||||
128);
|
||||
}
|
||||
};
|
||||
|
||||
FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) {
|
||||
return std::make_shared<FakeQuantizeFactory>(type_name);
|
||||
}
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#undef CREATE_UNARY_FACTORY
|
||||
@ -255,6 +271,9 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) {
|
||||
#undef CREATE_RESHAPE_FACTORY
|
||||
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)
|
||||
|
||||
#undef CREATE_FQ_FACTORY
|
||||
#define CREATE_FQ_FACTORY(type_name) common::CreateFakeQuantizeFactory(#type_name)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
vector<FactoryPtr> unary_factories = {
|
||||
@ -393,6 +412,42 @@ auto test_forward_binary = []() {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TSTestFixture, test_forward_binary());
|
||||
|
||||
auto test_forward_fq = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
parameter(element::f32, {1}),
|
||||
parameter(element::f32, {55, 1, 1, 1}),
|
||||
parameter(element::f32, {55, 55, 1, 1}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_unsqueeze_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec = out_vec;
|
||||
auto indices = make_shared<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 1, 2});
|
||||
new_out_vec[2] = make_shared<Unsqueeze>(out_vec[2], indices);
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_unsqueeze_for, set_transpose_for}, {{2}, {1, 2, 3, 4}}};
|
||||
test_case.model_ref.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonFQForward, TSTestFixture, test_forward_fq());
|
||||
|
||||
auto test_forward_concat = []() {
|
||||
TestCase test_case;
|
||||
|
||||
@ -867,6 +922,42 @@ auto test_backward_binary = []() {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TSTestFixture, test_backward_binary());
|
||||
|
||||
auto test_backward_fq = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1}),
|
||||
parameter(element::f32, {1, 96, 55, 1}),
|
||||
parameter(element::f32, {1, 96, 1, 1}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
auto set_unsqueeze_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec = out_vec;
|
||||
auto indices = make_shared<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 1, 2});
|
||||
new_out_vec[2] = make_shared<Unsqueeze>(out_vec[2], indices);
|
||||
return new_out_vec;
|
||||
};
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_unsqueeze_for, set_transpose_for}, {{2}, {0, 1, 2, 3, 4}}};
|
||||
test_case.model_ref.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonFQBackward, TSTestFixture, test_backward_fq());
|
||||
|
||||
auto test_backward_concat = []() {
|
||||
TestCase test_case;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user