[core]Migrate GridSample operator to new API (#20852)
* MIgrate GridSample to new API * Refactor GridSample to reduce binary size - use function pointer instead std::function (simpler less code size) - use RoundingGuard instead manual set/restore rounding mode - move interpolate selection outside main data processing loop
This commit is contained in:
parent
ae343a0178
commit
7d74dac3ee
@ -59,9 +59,7 @@ public:
|
||||
m_attributes = attributes;
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
|
||||
private:
|
||||
|
@ -12,6 +12,7 @@
|
||||
|
||||
#include "openvino/core/shape.hpp"
|
||||
#include "openvino/op/grid_sample.hpp"
|
||||
#include "openvino/reference/rounding_guard.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace reference {
|
||||
@ -20,10 +21,20 @@ namespace {
|
||||
using index_4D_t = typename std::array<size_t, 4>;
|
||||
|
||||
template <typename GRID_ET>
|
||||
using denormalize_fn_t = typename std::function<GRID_ET(GRID_ET, size_t)>;
|
||||
using denormalize_fn_t = GRID_ET (*)(GRID_ET, size_t);
|
||||
|
||||
template <typename DATA_ET>
|
||||
using get_padded_fn_t = typename std::function<DATA_ET(const DATA_ET*, const Shape&, size_t, size_t, long, long)>;
|
||||
using get_padded_fn_t = DATA_ET (*)(const DATA_ET*, const Shape&, size_t, size_t, long, long);
|
||||
|
||||
template <typename DATA_ET, typename GRID_ET>
|
||||
using interpolate_fn_t = DATA_ET (*)(const DATA_ET* data,
|
||||
const Shape&,
|
||||
const size_t n,
|
||||
const size_t c,
|
||||
const GRID_ET,
|
||||
const GRID_ET,
|
||||
const get_padded_fn_t<DATA_ET>&,
|
||||
const denormalize_fn_t<GRID_ET>&);
|
||||
|
||||
template <typename T>
|
||||
T& get_single_value(T* buffer, const Shape& shape, const index_4D_t& index) {
|
||||
@ -240,8 +251,7 @@ void grid_sample(DATA_ET* output,
|
||||
const auto W_out = grid_shape[2];
|
||||
const Shape output_shape{N, C, H_out, W_out};
|
||||
|
||||
const auto prev_rounding_mode = std::fegetround();
|
||||
std::fesetround(FE_TONEAREST);
|
||||
const RoundingGuard rounding_guard{FE_TONEAREST};
|
||||
|
||||
get_padded_fn_t<DATA_ET> get_padded_fn;
|
||||
switch (padding_mode) {
|
||||
@ -253,18 +263,25 @@ void grid_sample(DATA_ET* output,
|
||||
get_padded_fn = border_padding<DATA_ET>;
|
||||
break;
|
||||
case ov::op::v9::GridSample::PaddingMode::REFLECTION:
|
||||
if (align_corners)
|
||||
get_padded_fn = reflection_data_with_align<DATA_ET>;
|
||||
else
|
||||
get_padded_fn = reflection_data_no_align<DATA_ET>;
|
||||
get_padded_fn = align_corners ? reflection_data_with_align<DATA_ET> : reflection_data_no_align<DATA_ET>;
|
||||
break;
|
||||
}
|
||||
|
||||
denormalize_fn_t<GRID_ET> denormalize_fn;
|
||||
if (align_corners)
|
||||
denormalize_fn = rescale_align<GRID_ET>;
|
||||
else
|
||||
denormalize_fn = rescale_noalign<GRID_ET>;
|
||||
const auto denormalize_fn = align_corners ? rescale_align<GRID_ET> : rescale_noalign<GRID_ET>;
|
||||
|
||||
interpolate_fn_t<DATA_ET, GRID_ET> interpolate_fn;
|
||||
switch (interpolation_mode) {
|
||||
default:
|
||||
case ov::op::v9::GridSample::InterpolationMode::BILINEAR:
|
||||
interpolate_fn = bilinear<DATA_ET, GRID_ET>;
|
||||
break;
|
||||
case ov::op::v9::GridSample::InterpolationMode::NEAREST:
|
||||
interpolate_fn = nearest<DATA_ET, GRID_ET>;
|
||||
break;
|
||||
case ov::op::v9::GridSample::InterpolationMode::BICUBIC:
|
||||
interpolate_fn = bicubic<DATA_ET, GRID_ET>;
|
||||
break;
|
||||
}
|
||||
|
||||
for (size_t n = 0; n < N; ++n) {
|
||||
for (size_t c = 0; c < C; ++c) {
|
||||
@ -274,24 +291,11 @@ void grid_sample(DATA_ET* output,
|
||||
const auto x_n = get_single_value(grid, grid_shape, index_4D_t{n, y, x, 0});
|
||||
|
||||
auto& out = get_single_value(output, output_shape, index_4D_t{n, c, y, x});
|
||||
|
||||
switch (interpolation_mode) {
|
||||
case ov::op::v9::GridSample::InterpolationMode::BILINEAR:
|
||||
out = bilinear(data, data_shape, n, c, y_n, x_n, get_padded_fn, denormalize_fn);
|
||||
break;
|
||||
case ov::op::v9::GridSample::InterpolationMode::NEAREST:
|
||||
out = nearest(data, data_shape, n, c, y_n, x_n, get_padded_fn, denormalize_fn);
|
||||
break;
|
||||
case ov::op::v9::GridSample::InterpolationMode::BICUBIC:
|
||||
out = bicubic(data, data_shape, n, c, y_n, x_n, get_padded_fn, denormalize_fn);
|
||||
break;
|
||||
}
|
||||
out = interpolate_fn(data, data_shape, n, c, y_n, x_n, get_padded_fn, denormalize_fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::fesetround(prev_rounding_mode);
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace ov
|
||||
|
@ -4,19 +4,67 @@
|
||||
|
||||
#include "openvino/op/grid_sample.hpp"
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "grid_sample_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "openvino/reference/grid_sample.hpp"
|
||||
#include "validation_util.hpp"
|
||||
|
||||
namespace ov {
|
||||
op::v9::GridSample::GridSample(const Output<Node>& data, const Output<Node>& grid, const Attributes& attributes)
|
||||
namespace op {
|
||||
namespace v9 {
|
||||
|
||||
struct Evaluate : element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t ET, class T = fundamental_type_for<ET>>
|
||||
static result_type visit(Tensor& output,
|
||||
const Tensor& data,
|
||||
const Tensor& grid,
|
||||
const Shape& data_shape,
|
||||
const Shape& grid_shape,
|
||||
const GridSample::Attributes& attributes) {
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<f32>::apply<EvalByGridType>(grid.get_element_type(),
|
||||
output.data<T>(),
|
||||
data.data<const T>(),
|
||||
grid,
|
||||
data_shape,
|
||||
grid_shape,
|
||||
attributes);
|
||||
}
|
||||
|
||||
private:
|
||||
struct EvalByGridType : public element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t ET, class T, class G = fundamental_type_for<ET>>
|
||||
static result_type visit(T* output,
|
||||
const T* data,
|
||||
const Tensor& grid,
|
||||
const Shape& data_shape,
|
||||
const Shape& grid_shape,
|
||||
const GridSample::Attributes& attributes) {
|
||||
reference::grid_sample(output,
|
||||
data,
|
||||
grid.data<const G>(),
|
||||
data_shape,
|
||||
grid_shape,
|
||||
attributes.align_corners,
|
||||
attributes.mode,
|
||||
attributes.padding_mode);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
GridSample::GridSample(const Output<Node>& data, const Output<Node>& grid, const Attributes& attributes)
|
||||
: op::Op{{data, grid}},
|
||||
m_attributes{attributes} {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool op::v9::GridSample::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool GridSample::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v9_GridSample_visit_attributes);
|
||||
visitor.on_attribute("align_corners", m_attributes.align_corners);
|
||||
visitor.on_attribute("mode", m_attributes.mode);
|
||||
@ -24,7 +72,7 @@ bool op::v9::GridSample::visit_attributes(AttributeVisitor& visitor) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v9::GridSample::validate_and_infer_types() {
|
||||
void GridSample::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v9_GridSample_validate_and_infer_types);
|
||||
if (!get_input_element_type(1).is_dynamic()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
@ -39,12 +87,36 @@ void op::v9::GridSample::validate_and_infer_types() {
|
||||
set_output_type(0, get_input_element_type(0), out_shapes[0]);
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::v9::GridSample::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> GridSample::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v9_GridSample_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<op::v9::GridSample>(new_args.at(0), new_args.at(1), this->get_attributes());
|
||||
return std::make_shared<GridSample>(new_args.at(0), new_args.at(1), get_attributes());
|
||||
}
|
||||
|
||||
bool GridSample::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v9_GridSample_evaluate);
|
||||
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
|
||||
const auto& out_shape = shape_infer(this, ov::util::get_tensors_partial_shapes(inputs)).front().to_shape();
|
||||
outputs[0].set_shape(out_shape);
|
||||
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<f32>::apply<Evaluate>(inputs[0].get_element_type(),
|
||||
outputs[0],
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[0].get_shape(),
|
||||
inputs[1].get_shape(),
|
||||
m_attributes);
|
||||
}
|
||||
|
||||
bool GridSample::has_evaluate() const {
|
||||
return get_input_element_type(0) == element::f32 && get_input_element_type(1) == element::f32;
|
||||
}
|
||||
} // namespace v9
|
||||
} // namespace op
|
||||
|
||||
std::ostream& operator<<(std::ostream& s, const op::v9::GridSample::InterpolationMode& mode) {
|
||||
return s << as_string(mode);
|
||||
}
|
||||
@ -54,7 +126,7 @@ std::ostream& operator<<(std::ostream& s, const op::v9::GridSample::PaddingMode&
|
||||
}
|
||||
|
||||
template <>
|
||||
NGRAPH_API EnumNames<op::v9::GridSample::InterpolationMode>& EnumNames<op::v9::GridSample::InterpolationMode>::get() {
|
||||
OPENVINO_API EnumNames<op::v9::GridSample::InterpolationMode>& EnumNames<op::v9::GridSample::InterpolationMode>::get() {
|
||||
static auto enum_names =
|
||||
EnumNames<op::v9::GridSample::InterpolationMode>("op::v9::GridSample::InterpolationMode",
|
||||
{{"bilinear", op::v9::GridSample::InterpolationMode::BILINEAR},
|
||||
@ -64,7 +136,7 @@ NGRAPH_API EnumNames<op::v9::GridSample::InterpolationMode>& EnumNames<op::v9::G
|
||||
}
|
||||
|
||||
template <>
|
||||
NGRAPH_API EnumNames<op::v9::GridSample::PaddingMode>& EnumNames<op::v9::GridSample::PaddingMode>::get() {
|
||||
OPENVINO_API EnumNames<op::v9::GridSample::PaddingMode>& EnumNames<op::v9::GridSample::PaddingMode>::get() {
|
||||
static auto enum_names =
|
||||
EnumNames<op::v9::GridSample::PaddingMode>("op::v9::GridSample::PaddingMode",
|
||||
{{"zeros", op::v9::GridSample::PaddingMode::ZEROS},
|
||||
@ -72,73 +144,4 @@ NGRAPH_API EnumNames<op::v9::GridSample::PaddingMode>& EnumNames<op::v9::GridSam
|
||||
{"reflection", op::v9::GridSample::PaddingMode::REFLECTION}});
|
||||
return enum_names;
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace {
|
||||
|
||||
template <element::Type_t DATA_ET, element::Type_t GRID_ET>
|
||||
bool evaluate_exec(const ngraph::HostTensorPtr& output,
|
||||
const ngraph::HostTensorPtr& data,
|
||||
const ngraph::HostTensorPtr& grid,
|
||||
const op::v9::GridSample::Attributes& attributes) {
|
||||
ov::reference::grid_sample(output->get_data_ptr<DATA_ET>(),
|
||||
data->get_data_ptr<DATA_ET>(),
|
||||
grid->get_data_ptr<GRID_ET>(),
|
||||
data->get_shape(),
|
||||
grid->get_shape(),
|
||||
attributes.align_corners,
|
||||
attributes.mode,
|
||||
attributes.padding_mode);
|
||||
return true;
|
||||
}
|
||||
|
||||
#define GRID_SAMPLE_TYPE_CASE(a, ...) \
|
||||
case element::Type_t::a: { \
|
||||
OV_OP_SCOPE(OV_PP_CAT3(evaluate_exec_grid_sample, _, a)); \
|
||||
rc = evaluate_exec<DATA_ET, element::Type_t::a>(__VA_ARGS__); \
|
||||
} break
|
||||
|
||||
template <element::Type_t DATA_ET>
|
||||
bool evaluate(const ngraph::HostTensorPtr& output,
|
||||
const ngraph::HostTensorPtr& data,
|
||||
const ngraph::HostTensorPtr& grid,
|
||||
const op::v9::GridSample::Attributes& attributes) {
|
||||
auto rc = true;
|
||||
switch (grid->get_element_type()) {
|
||||
GRID_SAMPLE_TYPE_CASE(f32, output, data, grid, attributes);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
bool evaluate_grid_sample(const ngraph::HostTensorPtr& output,
|
||||
const ngraph::HostTensorPtr& data,
|
||||
const ngraph::HostTensorPtr& grid,
|
||||
const op::v9::GridSample::Attributes& attributes) {
|
||||
auto rc = true;
|
||||
switch (output->get_element_type()) {
|
||||
OPENVINO_TYPE_CASE(evaluate_grid_sample, f32, output, data, grid, attributes);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool op::v9::GridSample::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v9_GridSample_evaluate);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
OPENVINO_ASSERT(ngraph::validate_host_tensor_vector(inputs, 2), "Invalid GridSample input TensorVector.");
|
||||
OPENVINO_ASSERT(ngraph::validate_host_tensor_vector(outputs, 1), "Invalid GridSample output TensorVector.");
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
return evaluate_grid_sample(outputs[0], inputs[0], inputs[1], m_attributes);
|
||||
}
|
||||
|
||||
bool op::v9::GridSample::has_evaluate() const {
|
||||
return get_input_element_type(0) == element::f32 && get_input_element_type(1) == element::f32;
|
||||
}
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user