Files
openvino/src/common/legacy/src/ngraph_ops/interp.cpp
Ilya Churaev d5d67b2be8 Removed using of BWDCMP macro (#14292)
* Removed using of BWDCMP macro

* Fixed code style

* Removed type info

* Fixed ONNX tests

* Fixed CPU tests

* Fixed ExecutionNode
2022-12-07 12:47:16 +04:00

121 lines
4.3 KiB
C++
Executable File

// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "legacy/ngraph_ops/interp.hpp"
#include <limits>
#include <memory>
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
op::Interp::Interp(const Output<Node>& image, const InterpolateIEAttrs& attrs)
: Op({image}), m_attrs(attrs) {
constructor_validate_and_infer_types();
}
void op::Interp::validate_and_infer_types() {
if (get_input_partial_shape(0).is_static()) {
Shape input_shape = get_input_partial_shape(0).to_shape();
Shape output_shape(4);
// Assumes {N, C, H, W}
output_shape[0] = input_shape[0];
output_shape[1] = input_shape[1];
auto is_zero = [](float value) {
return std::fabs(value) < std::numeric_limits<float>::epsilon();
};
bool should_scale =
!(is_zero(m_attrs.zoom_factor) && is_zero(m_attrs.shrink_factor) && is_zero(m_attrs.scale_factor));
if (should_scale) {
float scale = m_attrs.scale_factor;
if (!is_zero(m_attrs.shrink_factor) || !is_zero(m_attrs.zoom_factor)) {
if (!is_zero(m_attrs.zoom_factor)) {
scale = m_attrs.zoom_factor;
}
if (!is_zero(m_attrs.shrink_factor)) {
scale /= m_attrs.shrink_factor;
}
}
output_shape[2] = static_cast<Shape::value_type>(input_shape[2] * scale);
output_shape[3] = static_cast<Shape::value_type>(input_shape[3] * scale);
}
if (m_attrs.height > 0) {
output_shape[2] = m_attrs.height;
}
if (m_attrs.width > 0) {
output_shape[3] = m_attrs.width;
}
set_output_type(0, get_input_element_type(0), output_shape);
} else {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
}
shared_ptr<Node> op::Interp::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<Interp>(new_args.at(0), m_attrs);
}
bool op::Interp::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("align_corners", m_attrs.align_corners);
visitor.on_attribute("width", m_attrs.width);
visitor.on_attribute("height", m_attrs.height);
visitor.on_attribute("pad_beg", m_attrs.pad_beg);
visitor.on_attribute("pad_end", m_attrs.pad_end);
return true;
}
op::ResampleV2::ResampleV2(const Output<Node>& image, const Output<Node>& output_shape,
const ResampleIEAttrs& attrs)
: Op({image, output_shape}), m_attrs(attrs) {
constructor_validate_and_infer_types();
}
op::ResampleV2::ResampleV2(const Output<Node>& image, const ResampleIEAttrs& attrs)
: Op({image}), m_attrs(attrs) {
constructor_validate_and_infer_types();
}
void op::ResampleV2::validate_and_infer_types() {
if (m_attrs.factor != 0) {
Shape output_shape(get_input_shape(0));
for (size_t i = 2; i < output_shape.size(); ++i) {
output_shape[i] *= m_attrs.factor;
}
set_output_type(0, get_input_element_type(0), output_shape);
} else if (auto const_shape = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())) {
NODE_VALIDATION_CHECK(this, shape_size(const_shape->get_shape()) == 4 || shape_size(const_shape->get_shape()) == 5,
"Layer shape must have rank 4 or 5", const_shape->get_shape());
auto out_shape = const_shape->cast_vector<int64_t>();
Shape output_shape;
for (size_t i = 0; i < const_shape->get_shape()[0]; i++) {
output_shape.push_back((out_shape[i] > 0) ? out_shape[i] : 0);
}
set_output_type(0, get_input_element_type(0), output_shape);
} else {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
}
shared_ptr<Node> op::ResampleV2::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<ResampleV2>(new_args.at(0), new_args.at(1), m_attrs);
}
bool op::ResampleV2::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("antialias", m_attrs.antialias);
visitor.on_attribute("factor", m_attrs.factor);
visitor.on_attribute("mode", m_attrs.mode);
return true;
}