LogicalNot: convert precision (#17061)

* CVS-108362 LogicalNot: convert precision

* Test
This commit is contained in:
Evgenya Stepyreva
2023-04-25 16:43:44 +04:00
committed by GitHub
parent 2e3deb8d8f
commit cd4c012f08
2 changed files with 12 additions and 6 deletions

View File

@@ -84,14 +84,12 @@ bool fuse_type_to_logical(const std::shared_ptr<ngraph::Node>& node, const preci
const auto& to = it->second;
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_overridden_output_type(to);
type_relaxed->set_origin_input_type(ov::element::boolean, 0);
type_relaxed->set_origin_input_type(ov::element::boolean, 1);
for (size_t i = 0; i < node->get_input_size(); ++i)
type_relaxed->set_origin_input_type(ov::element::boolean, i);
return true;
} else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
auto relaxed_op = std::make_shared<ov::op::TypeRelaxed<T>>(
*casted,
ov::element::TypeVector{ov::element::boolean, ov::element::boolean},
ov::element::TypeVector{to});
ov::element::TypeVector input_types(node->get_input_size(), ov::element::boolean);
auto relaxed_op = std::make_shared<ov::op::TypeRelaxed<T>>(*casted, input_types, ov::element::TypeVector{to});
replace_node(node, relaxed_op);
return true;
}

View File

@@ -671,6 +671,14 @@ TEST(TransformationTests, ConvertPrecision_LogicalNot) {
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
std::shared_ptr<ov::op::TypeRelaxedBase> tr;
for (const auto& node : f->get_ordered_ops())
if (auto op = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node))
tr = op;
ASSERT_TRUE(tr != nullptr);
ASSERT_EQ(tr->get_origin_input_type(0), element::boolean);
ASSERT_EQ(tr->get_origin_input_type(1), element::undefined);
}
TEST(TransformationTests, ConvertPrecision_Select) {