[Common Frontend] Move ComplexTypeMark into common frontend (#21409)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-12-01 12:12:19 +04:00 committed by GitHub
parent abfbdd1b96
commit 9ecebdd202
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 36 deletions

View File

@ -0,0 +1,49 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/type/element_type.hpp"
#include "openvino/op/util/framework_node.hpp"
namespace ov {
namespace frontend {
// ComplexTypeMark serves to mark places that require complex type propagation
// that means to represent native complex type with simulating floating-point tensor
// that has one extra dimension to concatenate real and imaginary parts of complex tensor.
// For example, a tensor of complex type with shape [N1, N2, ..., Nk] will be transformed
// into a floating-point tensor [N1, N2, ..., Nk, 2]
// where a slice with index [..., 0] represents a real part and
// a slice with index [..., 1] represents a imaginary part.
class ComplexTypeMark : public ov::op::util::FrameworkNode {
public:
OPENVINO_OP("ComplexTypeMark", "util", ov::op::util::FrameworkNode);
ComplexTypeMark(const ov::Output<ov::Node>& input, const ov::element::Type& complex_part_type)
: ov::op::util::FrameworkNode(ov::OutputVector{input}, 1),
m_complex_part_type(complex_part_type) {
validate_and_infer_types();
}
void validate_and_infer_types() override {
set_output_type(0, ov::element::dynamic, PartialShape::dynamic());
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
auto complex_type_mark = std::make_shared<ComplexTypeMark>(inputs[0], m_complex_part_type);
complex_type_mark->set_attrs(get_attrs());
return complex_type_mark;
}
ov::element::Type get_complex_part_type() const {
return m_complex_part_type;
}
private:
ov::element::Type m_complex_part_type;
};
} // namespace frontend
} // namespace ov

View File

@ -4,47 +4,13 @@
#pragma once
#include "openvino/core/type/element_type.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/frontend/complex_type_mark.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
// ComplexTypeMark serves to mark places that require complex type propagation
// that means to represent native complex type with simulating floating-point tensor
// that has one extra dimension to concatenate real and imaginary parts of complex tensor.
// For example, a tensor of complex type with shape [N1, N2, ..., Nk] will be transformed
// into a floating-point tensor [N1, N2, ..., Nk, 2]
// where a slice with index [..., 0] represents a real part and
// a slice with index [..., 1] represents a imaginary part.
class ComplexTypeMark : public ov::op::util::FrameworkNode {
public:
OPENVINO_OP("ComplexTypeMark", "util", ov::op::util::FrameworkNode);
ComplexTypeMark(const ov::Output<ov::Node>& input, const ov::element::Type& complex_part_type)
: ov::op::util::FrameworkNode(ov::OutputVector{input}, 1),
m_complex_part_type(complex_part_type) {
validate_and_infer_types();
}
void validate_and_infer_types() override {
set_output_type(0, ov::element::dynamic, PartialShape::dynamic());
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
auto complex_type_mark = std::make_shared<ComplexTypeMark>(inputs[0], m_complex_part_type);
complex_type_mark->set_attrs(get_attrs());
return complex_type_mark;
}
ov::element::Type get_complex_part_type() const {
return m_complex_part_type;
}
private:
ov::element::Type m_complex_part_type;
};
using ov::frontend::ComplexTypeMark;
} // namespace tensorflow
} // namespace frontend