【PaddlePaddle Hackathon 3】Add Paddle p_norm operator (#12299)
* add paddle_p_norm * remove cout * modify ngraph to ov Co-authored-by: Bo Liu <bo4.liu@intel.com>
This commit is contained in:
parent
82c0d58140
commit
b9e2f28850
@ -282,6 +282,14 @@ static const std::vector<std::string> models{
|
||||
std::string("nearest_downsample_false_1"),
|
||||
std::string("nearest_upsample_false_0"),
|
||||
std::string("nearest_upsample_false_1"),
|
||||
std::string("p_norm1"),
|
||||
std::string("p_norm2"),
|
||||
std::string("p_norm3"),
|
||||
std::string("p_norm4"),
|
||||
std::string("p_norm5"),
|
||||
std::string("p_norm6"),
|
||||
std::string("p_norm7"),
|
||||
std::string("p_norm8"),
|
||||
std::string("pad3d_test1"),
|
||||
std::string("pad3d_test2"),
|
||||
std::string("pad3d_test3"),
|
||||
|
@ -0,0 +1,87 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
# p_norm paddle model generator
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.fluid.layer_helper import LayerHelper
|
||||
|
||||
from save_model import saveModel
|
||||
|
||||
paddle.enable_static()
|
||||
|
||||
|
||||
def p_norm_ref(x, p=None, axis=None, epsilon=1e-12, keepdim=None, name=None):
|
||||
attrs = {
|
||||
'axis': axis,
|
||||
'porder': p,
|
||||
'keepdim': keepdim,
|
||||
'epsilon': epsilon,
|
||||
}
|
||||
helper = LayerHelper('p_norm', **locals())
|
||||
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
|
||||
helper.append_op(type='p_norm',
|
||||
inputs={'X': x},
|
||||
outputs={'Out': out},
|
||||
attrs=attrs)
|
||||
return out
|
||||
|
||||
|
||||
def p_norm(name: str, x, axis, p, keepdim):
|
||||
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
|
||||
node_x = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype)
|
||||
|
||||
out = p_norm_ref(node_x, axis=axis, p=p, keepdim=keepdim)
|
||||
cpu = paddle.static.cpu_places(1)
|
||||
exe = paddle.static.Executor(cpu[0])
|
||||
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(paddle.static.default_startup_program())
|
||||
outs = exe.run(
|
||||
feed={'x': x},
|
||||
fetch_list=[out])
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
|
||||
def main():
|
||||
input_shape = (2, 3, 4, 5, 6)
|
||||
input_data = np.random.rand(*input_shape).astype(np.float32)
|
||||
paddle_result = p_norm('p_norm1', input_data, axis=4, p=1.5, keepdim=True)
|
||||
|
||||
input_shape = (3, 4, 5)
|
||||
input_data = np.random.rand(*input_shape).astype(np.float32)
|
||||
paddle_result = p_norm('p_norm2', input_data, axis=0, p=0.0, keepdim=None)
|
||||
|
||||
input_shape = (4, 5, 6)
|
||||
input_data = np.random.rand(*input_shape).astype(np.float32)
|
||||
paddle_result = p_norm('p_norm3', input_data, axis=None, p=None, keepdim=True)
|
||||
|
||||
input_shape = (6, 3, 4)
|
||||
input_data = np.random.rand(*input_shape).astype(np.float32)
|
||||
paddle_result = p_norm('p_norm4', input_data, axis=1, p=float('inf'), keepdim=False)
|
||||
|
||||
input_shape = (3, 5, 6)
|
||||
input_data = np.random.rand(*input_shape).astype(np.float32)
|
||||
paddle_result = p_norm('p_norm5', input_data, axis=1, p=float('-inf'), keepdim=True)
|
||||
|
||||
input_shape = (3, 6, 7)
|
||||
input_data = np.zeros(input_shape).astype(np.float32)
|
||||
paddle_result = p_norm('p_norm6', input_data, axis=0, p=0.0, keepdim=None)
|
||||
|
||||
input_shape = (10)
|
||||
input_data = np.random.rand(input_shape).astype("float32")
|
||||
input_data[0:10:2] = 0
|
||||
paddle_result = p_norm('p_norm7', input_data, axis=0, p=0.0, keepdim=False)
|
||||
|
||||
input_data = np.array([[0, 1, 2, -10]]).astype("float32")
|
||||
paddle_result = p_norm('p_norm8', input_data, axis=1, p=0.0, keepdim=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
63
src/frontends/paddle/src/op/p_norm.cpp
Normal file
63
src/frontends/paddle/src/op/p_norm.cpp
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
NamedOutputs p_norm(const NodeContext& node) {
|
||||
auto data = node.get_input("X");
|
||||
const auto p = node.get_attribute<float>("porder", 2.0);
|
||||
const auto axis = node.get_attribute<int32_t>("axis", -1);
|
||||
const auto keepdim = node.get_attribute<bool>("keepdim", false);
|
||||
|
||||
const auto absNode = std::make_shared<default_opset::Abs>(data);
|
||||
const auto axisNode = default_opset::Constant::create(ov::element::i32, {1}, {axis});
|
||||
|
||||
if (p == std::numeric_limits<float>::infinity()) {
|
||||
return node.default_single_output_mapping(
|
||||
{std::make_shared<default_opset::ReduceMax>(absNode, axisNode, keepdim)},
|
||||
{"Out"});
|
||||
} else if (p == -std::numeric_limits<float>::infinity()) {
|
||||
return node.default_single_output_mapping(
|
||||
{std::make_shared<default_opset::ReduceMin>(absNode, axisNode, keepdim)},
|
||||
{"Out"});
|
||||
} else if (p == 0.0) {
|
||||
const auto input_dtype = data.get_element_type();
|
||||
const auto zero = default_opset::Constant::create(input_dtype, {1}, {0});
|
||||
const auto non_zero = std::make_shared<default_opset::NotEqual>(absNode, zero);
|
||||
const auto converted_non_zero = std::make_shared<default_opset::Convert>(non_zero, input_dtype);
|
||||
|
||||
const auto reduce_sum = std::make_shared<default_opset::ReduceSum>(converted_non_zero, axisNode, keepdim);
|
||||
const auto input_shape = data.get_partial_shape();
|
||||
// process 1-d input and keepdim=false, output shape is [1], instead of scalar.
|
||||
if (!keepdim) {
|
||||
PADDLE_OP_CHECK(node,
|
||||
input_shape.rank().is_static(),
|
||||
"input rank of p_norm must be static when keepdim=false and p=0.");
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
if (input_rank == 1) {
|
||||
const auto one = default_opset::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto out = std::make_shared<default_opset::Reshape>(reduce_sum, one, false);
|
||||
return node.default_single_output_mapping({out}, {"Out"});
|
||||
}
|
||||
}
|
||||
return node.default_single_output_mapping({reduce_sum}, {"Out"});
|
||||
} else {
|
||||
const auto power_factor = default_opset::Constant::create(ov::element::f32, Shape{1}, {p});
|
||||
const auto powNode = std::make_shared<default_opset::Power>(absNode, power_factor);
|
||||
const auto reduce_sum = std::make_shared<default_opset::ReduceSum>(powNode, axisNode, keepdim);
|
||||
const auto extract_factor = default_opset::Constant::create(ov::element::f32, Shape{1}, {1.0 / p});
|
||||
return node.default_single_output_mapping({std::make_shared<default_opset::Power>(reduce_sum, extract_factor)},
|
||||
{"Out"});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -59,6 +59,7 @@ OP_CONVERTER(matrix_nms);
|
||||
OP_CONVERTER(meshgrid);
|
||||
OP_CONVERTER(multiclass_nms);
|
||||
OP_CONVERTER(nearest_interp_v2);
|
||||
OP_CONVERTER(p_norm);
|
||||
OP_CONVERTER(pad3d);
|
||||
OP_CONVERTER(pow);
|
||||
OP_CONVERTER(pool2d);
|
||||
@ -155,6 +156,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"multiclass_nms3", op::multiclass_nms},
|
||||
{"nearest_interp_v2", op::nearest_interp_v2},
|
||||
{"nearest_interp", op::nearest_interp_v2},
|
||||
{"p_norm", op::p_norm},
|
||||
{"pad3d", op::pad3d},
|
||||
{"pow", op::pow},
|
||||
{"pool2d", op::pool2d},
|
||||
|
Loading…
Reference in New Issue
Block a user