LogicalNot: convert precision (#17061)
* CVS-108362 LogicalNot: convert precision * Test
This commit is contained in:
committed by
GitHub
parent
2e3deb8d8f
commit
cd4c012f08
@@ -84,14 +84,12 @@ bool fuse_type_to_logical(const std::shared_ptr<ngraph::Node>& node, const preci
|
||||
const auto& to = it->second;
|
||||
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
|
||||
type_relaxed->set_overridden_output_type(to);
|
||||
type_relaxed->set_origin_input_type(ov::element::boolean, 0);
|
||||
type_relaxed->set_origin_input_type(ov::element::boolean, 1);
|
||||
for (size_t i = 0; i < node->get_input_size(); ++i)
|
||||
type_relaxed->set_origin_input_type(ov::element::boolean, i);
|
||||
return true;
|
||||
} else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
|
||||
auto relaxed_op = std::make_shared<ov::op::TypeRelaxed<T>>(
|
||||
*casted,
|
||||
ov::element::TypeVector{ov::element::boolean, ov::element::boolean},
|
||||
ov::element::TypeVector{to});
|
||||
ov::element::TypeVector input_types(node->get_input_size(), ov::element::boolean);
|
||||
auto relaxed_op = std::make_shared<ov::op::TypeRelaxed<T>>(*casted, input_types, ov::element::TypeVector{to});
|
||||
replace_node(node, relaxed_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -671,6 +671,14 @@ TEST(TransformationTests, ConvertPrecision_LogicalNot) {
|
||||
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
|
||||
std::shared_ptr<ov::op::TypeRelaxedBase> tr;
|
||||
for (const auto& node : f->get_ordered_ops())
|
||||
if (auto op = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node))
|
||||
tr = op;
|
||||
ASSERT_TRUE(tr != nullptr);
|
||||
ASSERT_EQ(tr->get_origin_input_type(0), element::boolean);
|
||||
ASSERT_EQ(tr->get_origin_input_type(1), element::undefined);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_Select) {
|
||||
|
||||
Reference in New Issue
Block a user