[Opset13][PT FE] Update torch bitwise operators (#20339)
* Add opset-13 bitwise implementation
* Improvements in tests
* Add transformation BitwiseOps->LogicalOps for bool
* Improve existing tests to better test dtypes
* Disable transformations for supported bitwise ops
* Improve bitwise test inputs
* Apply review suggestions to src/common/transformations/src/transformations/op_conversions/convert_bitwise_to_logical_bool.cpp (4 commits co-authored by Katarzyna Mitrus <katarzyna.mitrus@intel.com>)
* Update to REGISTER_PASS

Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
Commit fdb22c8610 (parent 142a72d0f0), committed via GitHub.
transformations/op_conversions/convert_bitwise_to_logical_bool.hpp (new file)
@@ -0,0 +1,52 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
class TRANSFORMATIONS_API ConvertBitwiseAndToLogicalAnd;
class TRANSFORMATIONS_API ConvertBitwiseNotToLogicalNot;
class TRANSFORMATIONS_API ConvertBitwiseOrToLogicalOr;
class TRANSFORMATIONS_API ConvertBitwiseXorToLogicalXor;
}  // namespace pass
}  // namespace ov

class ov::pass::ConvertBitwiseAndToLogicalAnd : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertBitwiseAndToLogicalAnd", "0");
    ConvertBitwiseAndToLogicalAnd();
};
class ov::pass::ConvertBitwiseNotToLogicalNot : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertBitwiseNotToLogicalNot", "0");
    ConvertBitwiseNotToLogicalNot();
};
class ov::pass::ConvertBitwiseOrToLogicalOr : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertBitwiseOrToLogicalOr", "0");
    ConvertBitwiseOrToLogicalOr();
};
class ov::pass::ConvertBitwiseXorToLogicalXor : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertBitwiseXorToLogicalXor", "0");
    ConvertBitwiseXorToLogicalXor();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief Converts Bitwise operators to Logical for boolean datatype for plugins that don't support opset13 Bitwise
 */
class ConvertBitwiseToLogical : public ov::pass::GraphRewrite {
public:
    OPENVINO_RTTI("ConvertBitwiseToLogical", "0");
    ConvertBitwiseToLogical() {
        add_matcher<ov::pass::ConvertBitwiseAndToLogicalAnd>();
        add_matcher<ov::pass::ConvertBitwiseNotToLogicalNot>();
        add_matcher<ov::pass::ConvertBitwiseOrToLogicalOr>();
        add_matcher<ov::pass::ConvertBitwiseXorToLogicalXor>();
    }
};
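The header above exposes four standalone matcher passes plus the composite ConvertBitwiseToLogical GraphRewrite. A minimal usage sketch (not part of this commit; assumes an already loaded ov::Model):

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_bitwise_to_logical_bool.hpp"

// Sketch: run the bool-only Bitwise->Logical rewrite on a model, e.g. in a
// plugin that has no opset-13 Bitwise kernels.
void convert_bitwise_for_bool(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ConvertBitwiseToLogical>();  // registers all four matchers at once
    manager.run_passes(model);
}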
@@ -65,6 +65,7 @@
 #include "transformations/init_node_info.hpp"
 #include "transformations/op_conversions/batch_norm_decomposition.hpp"
 #include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
+#include "transformations/op_conversions/convert_bitwise_to_logical_bool.hpp"
 #include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
 #include "transformations/op_conversions/convert_convertlike.hpp"
 #include "transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp"
@@ -226,6 +227,11 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
     ADD_MATCHER(fq_fusions, MulFakeQuantizeFusion)
     fq_fusions->set_name("ov::pass::FakeQuantizeFusions");

+    // Temporary transformation to allow the PyTorch frontend to
+    // partially support bitwise operators with boolean inputs for plugins
+    // that haven't enabled BitwiseOps from opset13
+    REGISTER_PASS(manager, ConvertBitwiseToLogical)
+
     // StridesOptimization should be at the very end
     // because we cannot insert any MaxPools since they may prevent
     // other optimizations
src/common/transformations/src/transformations/op_conversions/convert_bitwise_to_logical_bool.cpp (new file)
@@ -0,0 +1,116 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_bitwise_to_logical_bool.hpp"

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/bitwise_and.hpp"
#include "openvino/op/bitwise_not.hpp"
#include "openvino/op/bitwise_or.hpp"
#include "openvino/op/bitwise_xor.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/logical_xor.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
ov::pass::ConvertBitwiseAndToLogicalAnd::ConvertBitwiseAndToLogicalAnd() {
    MATCHER_SCOPE(ConvertBitwiseAndToLogicalAnd);
    auto pattern =
        pattern::wrap_type<ov::op::v13::BitwiseAnd>({pattern::any_input(pattern::type_matches(element::boolean)),
                                                     pattern::any_input(pattern::type_matches(element::boolean))});

    const matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto bitwise = std::dynamic_pointer_cast<ov::op::v13::BitwiseAnd>(m.get_match_root());
        if (!bitwise || transformation_callback(bitwise)) {
            return false;
        }

        const auto logical = std::make_shared<ov::op::v1::LogicalAnd>(bitwise->input_value(0),
                                                                      bitwise->input_value(1),
                                                                      bitwise->get_autob());

        logical->set_friendly_name(bitwise->get_friendly_name());
        copy_runtime_info(bitwise, logical);
        replace_node(bitwise, logical);

        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(pattern, matcher_name);
    register_matcher(m, callback);
}
ov::pass::ConvertBitwiseNotToLogicalNot::ConvertBitwiseNotToLogicalNot() {
    MATCHER_SCOPE(ConvertBitwiseNotToLogicalNot);
    auto pattern =
        pattern::wrap_type<ov::op::v13::BitwiseNot>({pattern::any_input(pattern::type_matches(element::boolean))});

    const matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto bitwise = std::dynamic_pointer_cast<ov::op::v13::BitwiseNot>(m.get_match_root());
        if (!bitwise || transformation_callback(bitwise)) {
            return false;
        }

        const auto logical = std::make_shared<ov::op::v1::LogicalNot>(bitwise->input_value(0));

        logical->set_friendly_name(bitwise->get_friendly_name());
        copy_runtime_info(bitwise, logical);
        replace_node(bitwise, logical);

        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(pattern, matcher_name);
    register_matcher(m, callback);
}

ov::pass::ConvertBitwiseOrToLogicalOr::ConvertBitwiseOrToLogicalOr() {
    MATCHER_SCOPE(ConvertBitwiseOrToLogicalOr);
    auto pattern =
        pattern::wrap_type<ov::op::v13::BitwiseOr>({pattern::any_input(pattern::type_matches(element::boolean)),
                                                    pattern::any_input(pattern::type_matches(element::boolean))});

    const matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto bitwise = std::dynamic_pointer_cast<ov::op::v13::BitwiseOr>(m.get_match_root());
        if (!bitwise || transformation_callback(bitwise)) {
            return false;
        }

        const auto logical = std::make_shared<ov::op::v1::LogicalOr>(bitwise->input_value(0),
                                                                     bitwise->input_value(1),
                                                                     bitwise->get_autob());

        logical->set_friendly_name(bitwise->get_friendly_name());
        copy_runtime_info(bitwise, logical);
        replace_node(bitwise, logical);

        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(pattern, matcher_name);
    register_matcher(m, callback);
}

ov::pass::ConvertBitwiseXorToLogicalXor::ConvertBitwiseXorToLogicalXor() {
    MATCHER_SCOPE(ConvertBitwiseXorToLogicalXor);
    auto pattern =
        pattern::wrap_type<ov::op::v13::BitwiseXor>({pattern::any_input(pattern::type_matches(element::boolean)),
                                                     pattern::any_input(pattern::type_matches(element::boolean))});

    const matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto bitwise = std::dynamic_pointer_cast<ov::op::v13::BitwiseXor>(m.get_match_root());
        if (!bitwise || transformation_callback(bitwise)) {
            return false;
        }

        const auto logical = std::make_shared<ov::op::v1::LogicalXor>(bitwise->input_value(0),
                                                                      bitwise->input_value(1),
                                                                      bitwise->get_autob());

        logical->set_friendly_name(bitwise->get_friendly_name());
        copy_runtime_info(bitwise, logical);
        replace_node(bitwise, logical);

        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(pattern, matcher_name);
    register_matcher(m, callback);
}
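For intuition, a small sketch (not part of the commit) of what the And matcher above does: a BitwiseAnd whose two inputs are element::boolean is swapped for LogicalAnd, with friendly name and runtime info preserved; a non-boolean BitwiseAnd fails the pattern and is left untouched.

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/bitwise_and.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_bitwise_to_logical_bool.hpp"

int main() {
    // Boolean inputs, so the pattern (two boolean any_inputs) matches.
    auto lhs = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, ov::Shape{2, 2});
    auto rhs = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, ov::Shape{2, 2});
    auto band = std::make_shared<ov::op::v13::BitwiseAnd>(lhs, rhs);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{band}, ov::ParameterVector{lhs, rhs});

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::ConvertBitwiseAndToLogicalAnd>();
    manager.run_passes(model);
    // The result's producer is now ov::op::v1::LogicalAnd (friendly name and rt_info copied over).
    return 0;
}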
@@ -0,0 +1,124 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_bitwise_to_logical_bool.hpp"

#include <gtest/gtest.h>

#include <memory>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset13.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace testing;

namespace {

std::shared_ptr<ov::Model> create_bitwise_model(std::string op_type, const ov::element::Type input_type) {
    const auto lhs = std::make_shared<ov::opset13::Parameter>(input_type, ov::Shape{1, 3, 100, 100});
    const auto rhs = std::make_shared<ov::opset13::Parameter>(input_type, ov::Shape{1, 3, 100, 100});

    std::shared_ptr<ov::Node> bitwise;
    ParameterVector params{lhs, rhs};
    if (op_type == "and") {
        bitwise = std::make_shared<ov::opset13::BitwiseAnd>(lhs, rhs, op::AutoBroadcastType::NONE);
    } else if (op_type == "not") {
        bitwise = std::make_shared<ov::opset13::BitwiseNot>(lhs);
        params = {lhs};
    } else if (op_type == "or") {
        bitwise = std::make_shared<ov::opset13::BitwiseOr>(lhs, rhs, op::AutoBroadcastType::NONE);
    } else if (op_type == "xor") {
        bitwise = std::make_shared<ov::opset13::BitwiseXor>(lhs, rhs, op::AutoBroadcastType::NONE);
    }

    bitwise->set_friendly_name("bitwise");

    return std::make_shared<ov::Model>(bitwise->outputs(), params);
}

std::shared_ptr<ov::Model> create_logical_model(std::string op_type) {
    const auto lhs = std::make_shared<ov::opset1::Parameter>(ov::element::boolean, ov::Shape{1, 3, 100, 100});
    const auto rhs = std::make_shared<ov::opset1::Parameter>(ov::element::boolean, ov::Shape{1, 3, 100, 100});
    std::shared_ptr<ov::Node> logical;
    ParameterVector params = {lhs, rhs};
    if (op_type == "and") {
        logical = std::make_shared<ov::opset1::LogicalAnd>(lhs, rhs, op::AutoBroadcastType::NONE);
    } else if (op_type == "not") {
        logical = std::make_shared<ov::opset1::LogicalNot>(lhs);
        params = {lhs};
    } else if (op_type == "or") {
        logical = std::make_shared<ov::opset1::LogicalOr>(lhs, rhs, op::AutoBroadcastType::NONE);
    } else if (op_type == "xor") {
        logical = std::make_shared<ov::opset1::LogicalXor>(lhs, rhs, op::AutoBroadcastType::NONE);
    }

    logical->set_friendly_name("logical");

    return std::make_shared<ov::Model>(logical->outputs(), params);
}

}  // namespace

TEST_F(TransformationTestsF, ConvertBitwiseToLogical_and_i32) {
    auto transform = manager.register_pass<ov::pass::GraphRewrite>();
    transform->add_matcher<ConvertBitwiseToLogical>();
    model = create_bitwise_model("and", element::i32);
}

TEST_F(TransformationTestsF, ConvertBitwiseToLogical_not_i32) {
    auto transform = manager.register_pass<ov::pass::GraphRewrite>();
    transform->add_matcher<ConvertBitwiseToLogical>();
    model = create_bitwise_model("not", element::i32);
}

TEST_F(TransformationTestsF, ConvertBitwiseToLogical_or_i32) {
    auto transform = manager.register_pass<ov::pass::GraphRewrite>();
    transform->add_matcher<ConvertBitwiseToLogical>();
    model = create_bitwise_model("or", element::i32);
}

TEST_F(TransformationTestsF, ConvertBitwiseToLogical_xor_i32) {
    auto transform = manager.register_pass<ov::pass::GraphRewrite>();
    transform->add_matcher<ConvertBitwiseToLogical>();
    model = create_bitwise_model("xor", element::i32);
}

TEST_F(TransformationTestsF, ConvertBitwiseToLogical_and_boolean) {
    auto transform = manager.register_pass<ov::pass::GraphRewrite>();
    transform->add_matcher<ConvertBitwiseToLogical>();
    model = create_bitwise_model("and", element::boolean);
    model_ref = create_logical_model("and");
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

TEST_F(TransformationTestsF, ConvertBitwiseToLogical_not_boolean) {
    auto transform = manager.register_pass<ov::pass::GraphRewrite>();
    transform->add_matcher<ConvertBitwiseToLogical>();
    model = create_bitwise_model("not", element::boolean);
    model_ref = create_logical_model("not");
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

TEST_F(TransformationTestsF, ConvertBitwiseToLogical_or_boolean) {
    auto transform = manager.register_pass<ov::pass::GraphRewrite>();
    transform->add_matcher<ConvertBitwiseToLogical>();
    model = create_bitwise_model("or", element::boolean);
    model_ref = create_logical_model("or");
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

TEST_F(TransformationTestsF, ConvertBitwiseToLogical_xor_boolean) {
    auto transform = manager.register_pass<ov::pass::GraphRewrite>();
    transform->add_matcher<ConvertBitwiseToLogical>();
    model = create_bitwise_model("xor", element::boolean);
    model_ref = create_logical_model("xor");
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
@@ -3,10 +3,10 @@
 //

 #include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/op/logical_and.hpp"
-#include "openvino/op/logical_not.hpp"
-#include "openvino/op/logical_or.hpp"
-#include "openvino/op/logical_xor.hpp"
+#include "openvino/op/bitwise_and.hpp"
+#include "openvino/op/bitwise_not.hpp"
+#include "openvino/op/bitwise_or.hpp"
+#include "openvino/op/bitwise_xor.hpp"
 #include "utils.hpp"

 namespace ov {
@@ -17,9 +17,7 @@ namespace op {
 OutputVector translate_bitwise_not(const NodeContext& context) {
     num_inputs_check(context, 1, 2);
     auto x = context.get_input(0);
-    FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean),
-                                  "aten::bitwise_not supported only for boolean input");
-    auto not_x = context.mark_node(std::make_shared<ov::op::v1::LogicalNot>(x));
+    auto not_x = context.mark_node(std::make_shared<ov::op::v13::BitwiseNot>(x));
     if (!context.input_is_none(1)) {
         context.mutate_input(1, not_x);
     }
@@ -27,32 +25,38 @@ OutputVector translate_bitwise_not(const NodeContext& context) {
 };

 OutputVector translate_bitwise_and(const NodeContext& context) {
-    num_inputs_check(context, 2, 2);
+    num_inputs_check(context, 2, 3);
     auto x = context.get_input(0);
     auto y = context.get_input(1);
-    FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean),
-                                  "aten::bitwise_not supported only for boolean input");
-    auto and_x = context.mark_node(std::make_shared<ov::op::v1::LogicalAnd>(x, y));
+    align_eltwise_input_types(context, x, y, false);
+    auto and_x = context.mark_node(std::make_shared<ov::op::v13::BitwiseAnd>(x, y));
+    if (!context.input_is_none(2)) {
+        context.mutate_input(2, and_x);
+    }
     return {and_x};
 };

 OutputVector translate_bitwise_or(const NodeContext& context) {
-    num_inputs_check(context, 2, 2);
+    num_inputs_check(context, 2, 3);
     auto x = context.get_input(0);
     auto y = context.get_input(1);
-    FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean),
-                                  "aten::bitwise_not supported only for boolean input");
-    auto or_x = context.mark_node(std::make_shared<ov::op::v1::LogicalOr>(x, y));
+    align_eltwise_input_types(context, x, y, false);
+    auto or_x = context.mark_node(std::make_shared<ov::op::v13::BitwiseOr>(x, y));
+    if (!context.input_is_none(2)) {
+        context.mutate_input(2, or_x);
+    }
     return {or_x};
 };

 OutputVector translate_bitwise_xor(const NodeContext& context) {
-    num_inputs_check(context, 2, 2);
+    num_inputs_check(context, 2, 3);
     auto x = context.get_input(0);
     auto y = context.get_input(1);
-    FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean),
-                                  "aten::bitwise_xor supported only for boolean input");
-    auto xor_x = context.mark_node(std::make_shared<ov::op::v1::LogicalXor>(x, y));
+    align_eltwise_input_types(context, x, y, false);
+    auto xor_x = context.mark_node(std::make_shared<ov::op::v13::BitwiseXor>(x, y));
+    if (!context.input_is_none(2)) {
+        context.mutate_input(2, xor_x);
+    }
     return {xor_x};
 };
@@ -43,6 +43,7 @@ OP_CONVERTER(translate_batch_norm);
 OP_CONVERTER(translate_bitwise_and);
 OP_CONVERTER(translate_bitwise_not);
 OP_CONVERTER(translate_bitwise_or);
+OP_CONVERTER(translate_bitwise_xor);
 OP_CONVERTER(translate_cat);
 OP_CONVERTER(translate_cdist);
 OP_CONVERTER(translate_channel_shuffle);
@@ -230,11 +231,11 @@ OP_CONVERTER(translate_transpose_fx);
 // Supported ops for TorchScript
 const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
     return {
-        {"aten::__and__", op::translate_and},
+        {"aten::__and__", op::translate_bitwise_and},
         {"aten::__derive_index", op::translate_derive_index},
         {"aten::__getitem__", op::translate_getitem},
         {"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
-        {"aten::__or__", op::translate_or},
+        {"aten::__or__", op::translate_bitwise_or},
+        {"aten::__xor__", op::translate_bitwise_xor},
         {"aten::__range_length", op::translate_range_length},
         {"aten::_convolution", op::translate_convolution},
@@ -280,7 +281,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::broadcast_to", op::translate_expand},
         {"aten::baddbmm", op::translate_addmm},
         {"aten::batch_norm", op::translate_batch_norm},
+        {"aten::bitwise_and", op::translate_bitwise_and},
         {"aten::bitwise_not", op::translate_bitwise_not},
+        {"aten::bitwise_or", op::translate_bitwise_or},
+        {"aten::bitwise_xor", op::translate_bitwise_xor},
         {"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten::Bool", op::translate_bool},
         {"aten::cat", op::translate_cat},
@@ -33,6 +33,7 @@
 #include "transformations/control_flow/unroll_tensor_iterator.hpp"
 #include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
 #include "transformations/op_conversions/convert_batch_to_space.hpp"
+#include "transformations/op_conversions/convert_bitwise_to_logical_bool.hpp"
 #include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
 #include "transformations/op_conversions/convert_depth_to_space.hpp"
 #include "transformations/op_conversions/convert_gather_downgrade.hpp"
@@ -444,6 +445,11 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
     CPU_ENABLE_PASS_COMMON(manager, ov::pass::ConvertDetectionOutput1ToDetectionOutput8);
     CPU_ENABLE_PASS_COMMON(manager, ov::pass::ConvertROIAlign3To9);

+    CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertBitwiseAndToLogicalAnd);
+    CPU_ENABLE_PASS_COMMON(manager, ov::pass::ConvertBitwiseNotToLogicalNot);
+    CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertBitwiseOrToLogicalOr);
+    CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertBitwiseXorToLogicalXor);
+
     if (useLpt) {
         CPU_LPT_SCOPE(LowPrecisionTransformations_Part3);
         CPU_SET_CALLBACK_COMMON(manager,
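The enable/disable macros above map onto OpenVINO's PassConfig: conversions stay disabled for bitwise ops the CPU plugin executes natively and are enabled only for BitwiseNot. A sketch of the intent (not the literal macro expansion, which also handles per-scope bookkeeping):

// Conceptual equivalent of the CPU_{DISABLE,ENABLE}_PASS_COMMON calls above.
const auto config = manager.get_pass_config();
config->disable<ov::pass::ConvertBitwiseAndToLogicalAnd>();  // native BitwiseAnd kept
config->enable<ov::pass::ConvertBitwiseNotToLogicalNot>();   // BitwiseNot converted (presumably no native kernel yet)
config->disable<ov::pass::ConvertBitwiseOrToLogicalOr>();    // native BitwiseOr kept
config->disable<ov::pass::ConvertBitwiseXorToLogicalXor>();  // native BitwiseXor kept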
@@ -9,13 +9,11 @@ from pytorch_layer_test_class import PytorchLayerTest


 class TestAnd(PytorchLayerTest):
-
     def _prepare_input(self):
         return self.input_data

     def create_model_tensor_input(self):
         class aten_and_tensor(torch.nn.Module):
-
             def __init__(self) -> None:
                 super().__init__()

@@ -25,10 +23,9 @@ class TestAnd(PytorchLayerTest):
         ref_net = None

         return aten_and_tensor(), ref_net, "aten::__and__"

-
     def create_model_bool_input(self):
         class aten_and_bool(torch.nn.Module):
-
             def __init__(self) -> None:
                 super().__init__()

@@ -39,18 +36,43 @@ class TestAnd(PytorchLayerTest):

         return aten_and_bool(), ref_net, "aten::__and__"

+    def create_model_int_input(self):
+        class aten_and_int(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, int_a: int, int_b: int):
+                return int_a & int_b
+
+        ref_net = None
+
+        return aten_and_int(), ref_net, "aten::__and__"
+
     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_and_tensor(self, ie_device, precision, ir_version):
-        self.input_data = (np.array([True, False, False], dtype=np.bool_), np.array(
-            [True, True, False], dtype=np.bool_))
-        self._test(*self.create_model_tensor_input(),
-                   ie_device, precision, ir_version)
+        self.input_data = (
+            np.array([True, False, False], dtype=np.bool_),
+            np.array([True, True, False], dtype=np.bool_),
+        )
+        self._test(*self.create_model_tensor_input(), ie_device, precision, ir_version)

     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_and_bool(self, ie_device, precision, ir_version):
-        self.input_data = (np.array(True, dtype=np.bool_),
-                           np.array(True, dtype=np.bool_))
-        self._test(*self.create_model_bool_input(),
-                   ie_device, precision, ir_version)
+        self.input_data = (np.array(True, dtype=np.bool_), np.array(True, dtype=np.bool_))
+        self._test(*self.create_model_bool_input(), ie_device, precision, ir_version)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_and_int(self, ie_device, precision, ir_version):
+        self.input_data = (np.array(3, dtype=np.int32), np.array(4, dtype=np.int32))
+        self._test(*self.create_model_int_input(), ie_device, precision, ir_version)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_and_tensor_int(self, ie_device, precision, ir_version):
+        self.input_data = (np.array([3, 5, 8], dtype=np.int32), np.array([7, 11, 2], dtype=np.int32))
+        self._test(
+            *self.create_model_tensor_input(), ie_device, precision, ir_version, freeze_model=False, trace_model=True
+        )
@@ -1,29 +0,0 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestBitwiseNot(PytorchLayerTest):
    def _prepare_input(self):
        import numpy as np
        return ((np.random.randn(1, 5) > 0).astype(bool),)

    def create_model(self):
        import torch

        class aten_bitwise_not(torch.nn.Module):

            def forward(self, x):
                return torch.bitwise_not(x)

        ref_net = None

        return aten_bitwise_not(), ref_net, "aten::bitwise_not"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_bitwise_not(self, ie_device, precision, ir_version):
        self._test(*self.create_model(), ie_device, precision, ir_version)
tests/layer_tests/pytorch_tests/test_bitwise_ops.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest


class TestBitwiseOp(PytorchLayerTest):
    def _prepare_input(self, out, unary, lhs_dtype, rhs_dtype, lhs_shape, rhs_shape):
        choices = np.array([0, 1, 255, 7])
        x = np.random.choice(choices, lhs_shape).astype(lhs_dtype)
        if unary:
            return (x,) if not out else (x, np.zeros_like(x).astype(lhs_dtype))
        y = np.random.choice(choices, rhs_shape).astype(rhs_dtype)
        if not out:
            return x, y
        return x, y, np.zeros_like(x).astype(lhs_dtype) + np.zeros_like(y).astype(rhs_dtype)

    def create_model(self, op_name, out):
        ops = {
            "and": torch.bitwise_and,
            "or": torch.bitwise_or,
            "xor": torch.bitwise_xor,
            "not": torch.bitwise_not,
        }
        op = ops[op_name]

        class aten_bitwise(torch.nn.Module):
            def __init__(self, op, out) -> None:
                super().__init__()
                self.op = op
                if op == torch.bitwise_not:
                    self.forward = self.forward_not
                if out:
                    self.forward = self.forward_out if not op == torch.bitwise_not else self.forward_not_out

            def forward(self, tensor_a, tensor_b):
                return self.op(tensor_a, tensor_b)

            def forward_out(self, tensor_a, tensor_b, out):
                return self.op(tensor_a, tensor_b, out=out), out

            def forward_not(self, tensor_a):
                return self.op(tensor_a)

            def forward_not_out(self, tensor_a, out):
                return self.op(tensor_a, out=out), out

        ref_net = None

        return aten_bitwise(op, out), ref_net, f"aten::bitwise_{op_name}"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("op_type", ["and", "or", "not", "xor"])
    @pytest.mark.parametrize("lhs_dtype", ["bool", "int32", "uint8", "int64"])
    @pytest.mark.parametrize("rhs_dtype", ["bool", "int32", "uint8", "int64"])
    @pytest.mark.parametrize(
        ("lhs_shape", "rhs_shape"),
        [
            ([2, 3], [2, 3]),
            ([2, 3], []),
            ([], [2, 3]),
        ],
    )
    @pytest.mark.parametrize("out", [False, True])
    def test_bitwise_mixed_dtypes(
        self, op_type, out, lhs_dtype, rhs_dtype, lhs_shape, rhs_shape, ie_device, precision, ir_version
    ):
        self._test(
            *self.create_model(op_type, out),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={
                "out": out,
                "unary": op_type == "not",
                "lhs_dtype": lhs_dtype,
                "rhs_dtype": rhs_dtype,
                "lhs_shape": lhs_shape,
                "rhs_shape": rhs_shape,
            },
            freeze_model=False,
            trace_model=True,
        )


class TestBitwiseOperators(PytorchLayerTest):
    def _prepare_input(self, lhs_dtype, rhs_dtype, lhs_shape, rhs_shape):
        choices = np.array([0, 1, 255, 7])
        x = np.random.choice(choices, lhs_shape).astype(lhs_dtype)
        y = np.random.choice(choices, rhs_shape).astype(rhs_dtype)
        return x, y

    def create_model(self):
        class aten_bitwise(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs & rhs, ~lhs, lhs | rhs, lhs ^ rhs

        ref_net = None

        return aten_bitwise(), ref_net, ("aten::__and__", "aten::bitwise_not", "aten::__or__", "aten::__xor__")

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("lhs_dtype", ["bool", "int32"])
    @pytest.mark.parametrize("rhs_dtype", ["bool", "int32"])
    @pytest.mark.parametrize(
        ("lhs_shape", "rhs_shape"),
        [
            ([2, 3], [2, 3]),
            ([2, 3], []),
            ([], [2, 3]),
        ],
    )
    def test_bitwise_operators(self, lhs_dtype, rhs_dtype, lhs_shape, rhs_shape, ie_device, precision, ir_version):
        self._test(
            *self.create_model(),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={
                "lhs_dtype": lhs_dtype,
                "rhs_dtype": rhs_dtype,
                "lhs_shape": lhs_shape,
                "rhs_shape": rhs_shape,
            },
            trace_model=True,
            freeze_model=False,
        )
@@ -1,29 +1,78 @@
 # Copyright (C) 2018-2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+import numpy as np
 import pytest
+import torch

 from pytorch_layer_test_class import PytorchLayerTest


-class TestLog(PytorchLayerTest):
+class TestOr(PytorchLayerTest):
     def _prepare_input(self):
-        import numpy as np
-        return (np.random.randint(0, 255, (20, 30, 40, 50)),)
+        return self.input_data

-    def create_model(self):
-        import torch
+    def create_model_tensor_input(self):
+        class aten_or_tensor(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()

-        class aten_or(torch.nn.Module):
-            def forward(self, x):
-                res = torch.ByteTensor(x.size()).zero_()
-                res[:, :, :, 1:] = res[:, :, :, 1:] | (x[:, :, :, 1:] != x[:, :, :, :-1])
-                res[:, :, :, :-1] = res[:, :, :, :-1] | (x[:, :, :, 1:] != x[:, :, :, :-1])
-                return res.float()
+            def forward(self, tensor_a, tensor_b):
+                return tensor_a | tensor_b

-        return aten_or(), None, "aten::__or__"
+        ref_net = None
+
+        return aten_or_tensor(), ref_net, "aten::__or__"
+
+    def create_model_bool_input(self):
+        class aten_or_bool(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, bool_a: bool, bool_b: bool):
+                return bool_a | bool_b
+
+        ref_net = None
+
+        return aten_or_bool(), ref_net, "aten::__or__"
+
+    def create_model_int_input(self):
+        class aten_or_int(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, int_a: int, int_b: int):
+                return int_a | int_b
+
+        ref_net = None
+
+        return aten_or_int(), ref_net, "aten::__or__"

     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_or(self, ie_device, precision, ir_version):
-        self._test(*self.create_model(), ie_device, precision, ir_version,
-                   dynamic_shapes=False, trace_model=True, use_convert_model=True)
+    def test_or_tensor(self, ie_device, precision, ir_version):
+        self.input_data = (
+            np.array([True, False, False], dtype=np.bool_),
+            np.array([True, True, False], dtype=np.bool_),
+        )
+        self._test(*self.create_model_tensor_input(), ie_device, precision, ir_version)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_or_bool(self, ie_device, precision, ir_version):
+        self.input_data = (np.array(True, dtype=np.bool_), np.array(True, dtype=np.bool_))
+        self._test(*self.create_model_bool_input(), ie_device, precision, ir_version)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_or_int(self, ie_device, precision, ir_version):
+        self.input_data = (np.array(3, dtype=np.int32), np.array(4, dtype=np.int32))
+        self._test(*self.create_model_int_input(), ie_device, precision, ir_version)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_or_tensor_int(self, ie_device, precision, ir_version):
+        self.input_data = (np.array([3, 5, 8], dtype=np.int32), np.array([7, 11, 2], dtype=np.int32))
+        self._test(
+            *self.create_model_tensor_input(), ie_device, precision, ir_version, freeze_model=False, trace_model=True
+        )
@@ -9,13 +9,11 @@ from pytorch_layer_test_class import PytorchLayerTest


 class TestXor(PytorchLayerTest):
-
     def _prepare_input(self):
         return self.input_data

     def create_model_tensor_input(self):
         class aten_xor_tensor(torch.nn.Module):
-
             def __init__(self) -> None:
                 super().__init__()

@@ -28,7 +26,6 @@ class TestXor(PytorchLayerTest):

     def create_model_bool_input(self):
         class aten_xor_bool(torch.nn.Module):
-
             def __init__(self) -> None:
                 super().__init__()

@@ -41,7 +38,6 @@ class TestXor(PytorchLayerTest):

     def create_model_int_input(self):
         class aten_xor_int(torch.nn.Module):
-
             def __init__(self) -> None:
                 super().__init__()

@@ -55,33 +51,28 @@ class TestXor(PytorchLayerTest):
     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_xor_tensor(self, ie_device, precision, ir_version):
-        self.input_data = (np.array([True, False, False], dtype=np.bool_), np.array(
-            [True, True, False], dtype=np.bool_))
-        self._test(*self.create_model_tensor_input(),
-                   ie_device, precision, ir_version)
+        self.input_data = (
+            np.array([True, False, False], dtype=np.bool_),
+            np.array([True, True, False], dtype=np.bool_),
+        )
+        self._test(*self.create_model_tensor_input(), ie_device, precision, ir_version)

     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_xor_bool(self, ie_device, precision, ir_version):
-        self.input_data = (np.array(True, dtype=np.bool_),
-                           np.array(True, dtype=np.bool_))
-        self._test(*self.create_model_bool_input(),
-                   ie_device, precision, ir_version)
+        self.input_data = (np.array(True, dtype=np.bool_), np.array(True, dtype=np.bool_))
+        self._test(*self.create_model_bool_input(), ie_device, precision, ir_version)

-    @pytest.mark.xfail(reason="bitwise_xor is not implemented")
     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_xor_int(self, ie_device, precision, ir_version):
-        self.input_data = (np.array(3, dtype=np.int),
-                           np.array(4, dtype=np.int))
-        self._test(*self.create_model_int_input(),
-                   ie_device, precision, ir_version)
+        self.input_data = (np.array(3, dtype=np.int32), np.array(4, dtype=np.int32))
+        self._test(*self.create_model_int_input(), ie_device, precision, ir_version)

-    @pytest.mark.xfail(reason="bitwise_xor is not implemented")
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_xor_tensor(self, ie_device, precision, ir_version):
-        self.input_data = (np.array([3, 5, 8], dtype=np.int), np.array(
-            [7, 11, 2], dtype=np.int))
-        self._test(*self.create_model_tensor_input(),
-                   ie_device, precision, ir_version)
+    def test_xor_tensor_int(self, ie_device, precision, ir_version):
+        self.input_data = (np.array([3, 5, 8], dtype=np.int32), np.array([7, 11, 2], dtype=np.int32))
+        self._test(
+            *self.create_model_tensor_input(), ie_device, precision, ir_version, freeze_model=False, trace_model=True
+        )
|
||||
Reference in New Issue
Block a user