[PDPD Frontend] enable bicubic, trilinear, linear (#7731)
This commit is contained in:
parent
88cab67833
commit
602fb74fe4
@ -2,71 +2,72 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <node_context.hpp>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
std::shared_ptr<ngraph::Node> calculate_output_shape_based_on_scales(const Output<ngraph::Node>& data,
|
||||
const std::vector<float>& scale,
|
||||
Output<ngraph::Node>& scales) {
|
||||
FRONT_END_GENERAL_CHECK(scale.size() > 0);
|
||||
if (scale.size() == 1)
|
||||
scales = opset6::Constant::create<float>(element::f32, Shape{4}, {1, 1, scale[0], scale[0]});
|
||||
else if (scale.size() == 2)
|
||||
scales = opset6::Constant::create<float>(element::f32, Shape{4}, {1, 1, scale[0], scale[1]});
|
||||
else if (scale.size() == 3)
|
||||
scales = opset6::Constant::create<float>(element::f32, Shape{4}, {1, scale[0], scale[1], scale[2]});
|
||||
else
|
||||
scales = opset6::Constant::create<float>(element::f32,
|
||||
Shape{scale.size()},
|
||||
std::vector<float>(scale.begin(), scale.end()));
|
||||
const auto shape_of_data =
|
||||
std::make_shared<opset6::Convert>(std::make_shared<opset6::ShapeOf>(data), scales.get_element_type());
|
||||
const auto multiply = std::make_shared<opset6::Multiply>(shape_of_data, scales);
|
||||
const auto output_shape = std::make_shared<opset6::Convert>(multiply, ngraph::element::i64);
|
||||
using namespace default_opset;
|
||||
|
||||
static std::shared_ptr<ngraph::Node> calculate_output_shape_based_on_scales(const Output<ngraph::Node>& data,
|
||||
const std::vector<float>& scale,
|
||||
Output<ngraph::Node>& scales,
|
||||
const int space_dim) {
|
||||
const size_t scale_size = static_cast<size_t>(space_dim + 2);
|
||||
FRONT_END_GENERAL_CHECK(scale.size() > 0 && scale.size() <= scale_size);
|
||||
|
||||
std::vector<float> full_scales(scale_size, 1.0f);
|
||||
std::memcpy(&full_scales[scale_size - scale.size()], &scale[0], scale.size() * sizeof(float));
|
||||
scales = Constant::create<float>(element::f32, {scale_size}, full_scales);
|
||||
|
||||
const auto shape_of_data = std::make_shared<Convert>(std::make_shared<ShapeOf>(data), scales.get_element_type());
|
||||
const auto multiply = std::make_shared<Multiply>(shape_of_data, scales);
|
||||
const auto output_shape = std::make_shared<Convert>(multiply, ngraph::element::i64);
|
||||
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> calculate_scales_based_on_sizes(const Output<ngraph::Node>& data,
|
||||
const Output<ngraph::Node>& sizes) {
|
||||
static std::shared_ptr<ngraph::Node> calculate_scales_based_on_sizes(const Output<ngraph::Node>& data,
|
||||
const Output<ngraph::Node>& sizes) {
|
||||
const float epsilon = 1.0e-5;
|
||||
const auto shape_of_data =
|
||||
std::make_shared<opset6::Convert>(std::make_shared<opset6::ShapeOf>(data), ngraph::element::f32);
|
||||
const auto converted_sizes = std::make_shared<opset6::Convert>(sizes, ngraph::element::f32);
|
||||
const auto divide = std::make_shared<opset6::Divide>(converted_sizes, shape_of_data);
|
||||
const auto eps_node = std::make_shared<opset6::Constant>(ngraph::element::f32, Shape{}, epsilon);
|
||||
const auto scales = std::make_shared<opset6::Add>(divide, eps_node);
|
||||
const auto shape_of_data = std::make_shared<Convert>(std::make_shared<ShapeOf>(data), ngraph::element::f32);
|
||||
const auto converted_sizes = std::make_shared<Convert>(sizes, ngraph::element::f32);
|
||||
const auto divide = std::make_shared<Divide>(converted_sizes, shape_of_data);
|
||||
const auto eps_node = std::make_shared<Constant>(ngraph::element::f32, Shape{}, epsilon);
|
||||
const auto scales = std::make_shared<Add>(divide, eps_node);
|
||||
|
||||
return scales;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> extract_out_sizes(const Output<ngraph::Node>& data,
|
||||
const std::vector<int64_t>& out_sizes) {
|
||||
const auto shape_of_x = std::make_shared<opset6::ShapeOf>(data);
|
||||
auto shape_begin = opset6::Constant::create(element::i64, {1}, {0});
|
||||
auto shape_end = opset6::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto nc_node = std::make_shared<opset6::StridedSlice>(shape_of_x,
|
||||
shape_begin,
|
||||
shape_end,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0});
|
||||
auto hw_node = opset6::Constant::create<int64_t>(element::i64, Shape{2}, out_sizes);
|
||||
return std::make_shared<opset6::Concat>(OutputVector{nc_node, hw_node}, 0);
|
||||
static std::shared_ptr<ngraph::Node> extract_out_sizes(const Output<ngraph::Node>& data,
|
||||
const std::vector<int64_t>& out_sizes) {
|
||||
const auto shape_of_x = std::make_shared<ShapeOf>(data);
|
||||
const auto shape_begin = Constant::create(element::i64, {1}, {0});
|
||||
const int end_idx = static_cast<int>(out_sizes.size());
|
||||
const auto shape_end = Constant::create(element::i64, Shape{1}, {-end_idx});
|
||||
const auto nc_node = std::make_shared<StridedSlice>(shape_of_x,
|
||||
shape_begin,
|
||||
shape_end,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0});
|
||||
const auto hw_node = Constant::create<int64_t>(element::i64, Shape{out_sizes.size()}, out_sizes);
|
||||
return std::make_shared<Concat>(OutputVector{nc_node, hw_node}, 0);
|
||||
}
|
||||
|
||||
// TODO support different data_layout #55170
|
||||
|
||||
NamedOutputs interpolate(const NodeContext& node, const ngraph::opset6::Interpolate::InterpolateMode& mode) {
|
||||
auto x = node.get_ng_input("X");
|
||||
using InterpolateMode = ngraph::opset6::Interpolate::InterpolateMode;
|
||||
using CoordinateTransformMode = ngraph::opset6::Interpolate::CoordinateTransformMode;
|
||||
using Nearest_mode = ngraph::opset6::Interpolate::NearestMode;
|
||||
using InterpolateAttrs = ngraph::opset6::Interpolate::InterpolateAttrs;
|
||||
using ShapeCalcMode = ngraph::opset6::Interpolate::ShapeCalcMode;
|
||||
static NamedOutputs interpolate(const NodeContext& node,
|
||||
const Interpolate::InterpolateMode& mode,
|
||||
const int space_dim) {
|
||||
const auto x = node.get_ng_input("X");
|
||||
using InterpolateMode = Interpolate::InterpolateMode;
|
||||
using CoordinateTransformMode = Interpolate::CoordinateTransformMode;
|
||||
using Nearest_mode = Interpolate::NearestMode;
|
||||
using InterpolateAttrs = Interpolate::InterpolateAttrs;
|
||||
using ShapeCalcMode = Interpolate::ShapeCalcMode;
|
||||
|
||||
InterpolateAttrs attrs;
|
||||
|
||||
@ -74,45 +75,67 @@ NamedOutputs interpolate(const NodeContext& node, const ngraph::opset6::Interpol
|
||||
|
||||
auto out_w = node.get_attribute<int>("out_w");
|
||||
auto out_h = node.get_attribute<int>("out_h");
|
||||
auto out_d = node.get_attribute<int>("out_d");
|
||||
auto scale = node.get_attribute<std::vector<float>>("scale");
|
||||
Output<Node> scales;
|
||||
Output<Node> target_spatial_shape;
|
||||
bool out_flag = out_w <= 0;
|
||||
if (space_dim == 2) {
|
||||
out_flag |= out_h <= 0;
|
||||
} else if (space_dim == 3) {
|
||||
out_flag |= out_h <= 0 || out_d <= 0;
|
||||
}
|
||||
|
||||
if (node.has_ng_input("OutSize")) {
|
||||
attrs.shape_calculation_mode = ShapeCalcMode::SIZES;
|
||||
auto hw_shape = node.get_ng_input("OutSize");
|
||||
const auto shape_of_x = std::make_shared<opset6::ShapeOf>(x);
|
||||
auto shape_begin = opset6::Constant::create(element::i64, {1}, {0});
|
||||
auto shape_end = opset6::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto nc_node = std::make_shared<opset6::StridedSlice>(shape_of_x,
|
||||
shape_begin,
|
||||
shape_end,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0});
|
||||
target_spatial_shape = std::make_shared<opset6::Concat>(
|
||||
OutputVector{nc_node, std::make_shared<opset6::Convert>(hw_shape, element::i64)},
|
||||
0);
|
||||
const auto hw_shape = node.get_ng_input("OutSize");
|
||||
const auto shape_of_x = std::make_shared<ShapeOf>(x);
|
||||
const auto shape_begin = Constant::create(element::i64, {1}, {0});
|
||||
const auto shape_end = Constant::create(element::i64, Shape{1}, {-space_dim});
|
||||
const auto nc_node = std::make_shared<StridedSlice>(shape_of_x,
|
||||
shape_begin,
|
||||
shape_end,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0});
|
||||
target_spatial_shape =
|
||||
std::make_shared<Concat>(OutputVector{nc_node, std::make_shared<Convert>(hw_shape, element::i64)}, 0);
|
||||
scales = calculate_scales_based_on_sizes(x, target_spatial_shape);
|
||||
} else if (out_w <= 0 || out_h <= 0) {
|
||||
} else if (out_flag) {
|
||||
attrs.shape_calculation_mode = ShapeCalcMode::SCALES;
|
||||
target_spatial_shape = calculate_output_shape_based_on_scales(x, scale, scales);
|
||||
target_spatial_shape = calculate_output_shape_based_on_scales(x, scale, scales, space_dim);
|
||||
} else {
|
||||
attrs.shape_calculation_mode = ShapeCalcMode::SIZES;
|
||||
target_spatial_shape = extract_out_sizes(x, {out_h, out_w});
|
||||
std::vector<int64_t> sizes;
|
||||
if (space_dim == 1)
|
||||
sizes = {out_w};
|
||||
else if (space_dim == 2)
|
||||
sizes = {out_h, out_w};
|
||||
else
|
||||
sizes = {out_d, out_h, out_w};
|
||||
|
||||
target_spatial_shape = extract_out_sizes(x, sizes);
|
||||
scales = calculate_scales_based_on_sizes(x, target_spatial_shape);
|
||||
}
|
||||
|
||||
bool align_corners = node.get_attribute<bool>("align_corners");
|
||||
int32_t align_mode = node.get_attribute<int32_t>("align_mode");
|
||||
const bool align_corners = node.get_attribute<bool>("align_corners");
|
||||
const int32_t align_mode = node.get_attribute<int32_t>("align_mode");
|
||||
|
||||
if (mode == InterpolateMode::NEAREST) {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::ASYMMETRIC;
|
||||
} else if (!align_corners && align_mode == 1) {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::ASYMMETRIC;
|
||||
} else if (!align_corners && align_mode == 0) {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::HALF_PIXEL;
|
||||
} else if (align_corners) {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::ALIGN_CORNERS;
|
||||
} else if (mode == InterpolateMode::CUBIC) {
|
||||
if (!align_corners) {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::HALF_PIXEL;
|
||||
} else {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::ALIGN_CORNERS;
|
||||
}
|
||||
} else {
|
||||
if (!align_corners && align_mode == 1) {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::ASYMMETRIC;
|
||||
} else if (!align_corners && align_mode == 0) {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::HALF_PIXEL;
|
||||
} else if (align_corners) {
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::ALIGN_CORNERS;
|
||||
}
|
||||
}
|
||||
|
||||
attrs.nearest_mode = Nearest_mode::SIMPLE;
|
||||
@ -120,19 +143,33 @@ NamedOutputs interpolate(const NodeContext& node, const ngraph::opset6::Interpol
|
||||
attrs.pads_begin = {0, 0, 0, 0};
|
||||
attrs.pads_end = {0, 0, 0, 0};
|
||||
|
||||
return node.default_single_output_mapping(
|
||||
{std::make_shared<ngraph::opset6::Interpolate>(x, target_spatial_shape, scales, attrs)},
|
||||
{"Out"});
|
||||
return node.default_single_output_mapping({std::make_shared<Interpolate>(x, target_spatial_shape, scales, attrs)},
|
||||
{"Out"});
|
||||
}
|
||||
|
||||
NamedOutputs linear_interp_v2(const NodeContext& node) {
|
||||
const auto mode = Interpolate::InterpolateMode::LINEAR_ONNX;
|
||||
return interpolate(node, mode, 1);
|
||||
}
|
||||
|
||||
NamedOutputs bilinear_interp_v2(const NodeContext& node) {
|
||||
auto mode = ngraph::opset6::Interpolate::InterpolateMode::LINEAR_ONNX;
|
||||
return interpolate(node, mode);
|
||||
const auto mode = Interpolate::InterpolateMode::LINEAR_ONNX;
|
||||
return interpolate(node, mode, 2);
|
||||
}
|
||||
|
||||
NamedOutputs trilinear_interp_v2(const NodeContext& node) {
|
||||
const auto mode = Interpolate::InterpolateMode::LINEAR_ONNX;
|
||||
return interpolate(node, mode, 3);
|
||||
}
|
||||
|
||||
NamedOutputs nearest_interp_v2(const NodeContext& node) {
|
||||
auto mode = ngraph::opset6::Interpolate::InterpolateMode::NEAREST;
|
||||
return interpolate(node, mode);
|
||||
const auto mode = Interpolate::InterpolateMode::NEAREST;
|
||||
return interpolate(node, mode, 2);
|
||||
}
|
||||
|
||||
NamedOutputs bicubic_interp_v2(const NodeContext& node) {
|
||||
const auto mode = Interpolate::InterpolateMode::CUBIC;
|
||||
return interpolate(node, mode, 2);
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
@ -11,6 +11,7 @@ namespace op {
|
||||
OP_CONVERTER(argmax);
|
||||
OP_CONVERTER(assign_value);
|
||||
OP_CONVERTER(batch_norm);
|
||||
OP_CONVERTER(bicubic_interp_v2);
|
||||
OP_CONVERTER(bilinear_interp_v2);
|
||||
OP_CONVERTER(cast);
|
||||
OP_CONVERTER(clip);
|
||||
@ -41,6 +42,7 @@ OP_CONVERTER(hard_sigmoid);
|
||||
OP_CONVERTER(hard_swish);
|
||||
OP_CONVERTER(layer_norm);
|
||||
OP_CONVERTER(leaky_relu);
|
||||
OP_CONVERTER(linear_interp_v2);
|
||||
OP_CONVERTER(log);
|
||||
OP_CONVERTER(logical_not);
|
||||
OP_CONVERTER(matmul);
|
||||
@ -68,6 +70,7 @@ OP_CONVERTER(squeeze);
|
||||
OP_CONVERTER(stack);
|
||||
OP_CONVERTER(tanh);
|
||||
OP_CONVERTER(transpose2);
|
||||
OP_CONVERTER(trilinear_interp_v2);
|
||||
OP_CONVERTER(unsqueeze);
|
||||
OP_CONVERTER(yolo_box);
|
||||
} // namespace op
|
||||
@ -82,6 +85,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
return {{"arg_max", op::argmax},
|
||||
{"assign_value", op::assign_value},
|
||||
{"batch_norm", op::batch_norm},
|
||||
{"bicubic_interp_v2", op::bicubic_interp_v2},
|
||||
{"bilinear_interp_v2", op::bilinear_interp_v2},
|
||||
{"bilinear_interp", op::bilinear_interp_v2},
|
||||
{"bmm", op::matmul},
|
||||
@ -116,6 +120,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"hard_swish", op::hard_swish},
|
||||
{"layer_norm", op::layer_norm},
|
||||
{"leaky_relu", op::leaky_relu},
|
||||
{"linear_interp_v2", op::linear_interp_v2},
|
||||
{"log", op::log},
|
||||
{"logical_not", op::logical_not},
|
||||
{"lookup_table_v2", op::embedding},
|
||||
@ -147,6 +152,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"sync_batch_norm", op::batch_norm},
|
||||
{"tanh", op::tanh},
|
||||
{"transpose2", op::transpose2},
|
||||
{"trilinear_interp_v2", op::trilinear_interp_v2},
|
||||
{"unsqueeze2", op::unsqueeze},
|
||||
{"yolo_box", op::yolo_box}};
|
||||
};
|
||||
|
@ -41,6 +41,14 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("avgPool_test9"),
|
||||
std::string("batch_norm_nchw"),
|
||||
std::string("batch_norm_nhwc"),
|
||||
std::string("bicubic_downsample_false_0"),
|
||||
std::string("bicubic_downsample_false_1"),
|
||||
std::string("bicubic_downsample_true_0"),
|
||||
std::string("bicubic_upsample_false_0"),
|
||||
std::string("bicubic_upsample_false_1"),
|
||||
std::string("bicubic_upsample_scales"),
|
||||
std::string("bicubic_upsample_scales2"),
|
||||
std::string("bicubic_upsample_true_0"),
|
||||
std::string("bilinear_downsample_false_0"),
|
||||
std::string("bilinear_downsample_false_1"),
|
||||
std::string("bilinear_downsample_true_0"),
|
||||
@ -118,6 +126,14 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("layer_norm_noscale"),
|
||||
std::string("layer_norm_noshift"),
|
||||
std::string("leaky_relu"),
|
||||
std::string("linear_downsample_false_0"),
|
||||
std::string("linear_downsample_false_1"),
|
||||
std::string("linear_downsample_true_0"),
|
||||
std::string("linear_upsample_false_0"),
|
||||
std::string("linear_upsample_false_1"),
|
||||
std::string("linear_upsample_scales"),
|
||||
std::string("linear_upsample_scales2"),
|
||||
std::string("linear_upsample_true_0"),
|
||||
std::string("log"),
|
||||
std::string("logical_not"),
|
||||
std::string("matmul_xt"),
|
||||
@ -204,6 +220,14 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("stack_test_neg_axis"),
|
||||
std::string("stack_test_none_axis"),
|
||||
std::string("tanh"),
|
||||
std::string("trilinear_downsample_false_0"),
|
||||
std::string("trilinear_downsample_false_1"),
|
||||
std::string("trilinear_downsample_true_0"),
|
||||
std::string("trilinear_upsample_false_0"),
|
||||
std::string("trilinear_upsample_false_1"),
|
||||
std::string("trilinear_upsample_scales"),
|
||||
std::string("trilinear_upsample_scales2"),
|
||||
std::string("trilinear_upsample_true_0"),
|
||||
std::string("unsqueeze"),
|
||||
// Temporily disable them until root caused to secure CI stable.
|
||||
// CVS-66703 to track this.
|
||||
|
@ -177,12 +177,264 @@ def bilinear_upsample_scales():
|
||||
pdpd_result = pdpd_interpolate(data, None, 2, mode='bilinear', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCHW', name=test['name'])
|
||||
|
||||
# trilinear
|
||||
def resize_upsample_trilinear():
|
||||
data = np.array([[[[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
],[
|
||||
[13, 14, 15, 16],
|
||||
[9, 10, 11, 12],
|
||||
[5, 6, 7, 8],
|
||||
[1, 2, 3, 4],
|
||||
]]]], dtype=np.float32)
|
||||
|
||||
test_case = [{'name': 'trilinear_upsample_false_1', 'align_corners': False, 'align_mode': 1},
|
||||
{'name': 'trilinear_upsample_false_0', 'align_corners': False, 'align_mode': 0},
|
||||
{'name': 'trilinear_upsample_true_0', 'align_corners': True, 'align_mode': 0}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data, [4, 64, 64], None, mode='TRILINEAR', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCDHW', name=test['name'])
|
||||
|
||||
|
||||
def resize_downsample_trilinear():
|
||||
data = np.array([[[[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
],[
|
||||
[13, 14, 15, 16],
|
||||
[9, 10, 11, 12],
|
||||
[5, 6, 7, 8],
|
||||
[1, 2, 3, 4]
|
||||
]]]], dtype=np.float32)
|
||||
data_28 = data.reshape([1, 1, 2, 2, 8])
|
||||
test_case = [{'name': 'trilinear_downsample_false_1', 'align_corners': False, 'align_mode': 1},
|
||||
{'name': 'trilinear_downsample_false_0', 'align_corners': False, 'align_mode': 0},
|
||||
{'name': 'trilinear_downsample_true_0', 'align_corners': True, 'align_mode': 0}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data_28, [2, 2, 4], None, mode='TRILINEAR', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCDHW', name=test['name'])
|
||||
|
||||
def trilinear_upsample_tensor_size():
|
||||
data = np.array([[[[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
]]]], dtype=np.float32)
|
||||
sizes = np.array([2, 8, 8], dtype="int32")
|
||||
|
||||
test_case = [{'name': 'trilinear_upsample_tensor_size', 'align_corners': False, 'align_mode': 1}]
|
||||
|
||||
for test in test_case:
|
||||
main_program = pdpd.static.Program()
|
||||
startup_program = pdpd.static.Program()
|
||||
with pdpd.static.program_guard(main_program, startup_program):
|
||||
node_x = pdpd.static.data(name='x', shape=data.shape, dtype='float32')
|
||||
node_sizes = pdpd.static.data(name='sizes', shape=sizes.shape, dtype='int32')
|
||||
interp = interpolate(node_x, size=node_sizes, scale_factor=None,
|
||||
mode='TRILINEAR', align_corners=test['align_corners'], align_mode=test['align_mode'],
|
||||
data_format='NCDHW', name=test['name'])
|
||||
out = pdpd.static.nn.batch_norm(interp, use_global_stats=True, epsilon=0)
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
exe.run(startup_program)
|
||||
outs = exe.run(
|
||||
feed={'x': data, 'sizes': sizes},
|
||||
fetch_list=out,
|
||||
program=main_program)
|
||||
saveModel(test['name'], exe, feedkeys=['x', 'sizes'], fetchlist=out, inputs=[data, sizes], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
def trilinear_upsample_scales():
|
||||
data = np.array([[[[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
]]]], dtype=np.float32)
|
||||
|
||||
test_case = [{'name': 'trilinear_upsample_scales', 'align_corners': False, 'align_mode': 1, "scales": 2},
|
||||
{'name': 'trilinear_upsample_scales2', 'align_corners': False, 'align_mode': 1, "scales": [1, 2, 2]}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data, None, 3, mode='TRILINEAR', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCDHW', name=test['name'])
|
||||
|
||||
|
||||
# bicubic
|
||||
def resize_upsample_bicubic():
|
||||
data = np.array([[[
|
||||
[1, 2, 3],
|
||||
[4, 5, 6],
|
||||
[7, 8, 9]
|
||||
]]], dtype=np.float32)
|
||||
|
||||
test_case = [{'name': 'bicubic_upsample_false_1', 'align_corners': False, 'align_mode': 1},
|
||||
{'name': 'bicubic_upsample_false_0', 'align_corners': False, 'align_mode': 0},
|
||||
{'name': 'bicubic_upsample_true_0', 'align_corners': True, 'align_mode': 0}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data, [6, 6], None, mode='bicubic', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCHW', name=test['name'])
|
||||
|
||||
|
||||
def resize_downsample_bicubic():
|
||||
data = np.array([[[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
]]], dtype=np.float32)
|
||||
data_28 = data.reshape([1, 1, 2, 8])
|
||||
test_case = [{'name': 'bicubic_downsample_false_1', 'align_corners': False, 'align_mode': 1},
|
||||
{'name': 'bicubic_downsample_false_0', 'align_corners': False, 'align_mode': 0},
|
||||
{'name': 'bicubic_downsample_true_0', 'align_corners': True, 'align_mode': 0}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data_28, [2, 4], None, mode='bicubic', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCHW', name=test['name'])
|
||||
|
||||
def bicubic_upsample_tensor_size():
|
||||
data = np.array([[[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
]]], dtype=np.float32)
|
||||
sizes = np.array([8, 8], dtype="int32")
|
||||
|
||||
test_case = [{'name': 'bicubic_upsample_tensor_size', 'align_corners': False, 'align_mode': 1}]
|
||||
|
||||
for test in test_case:
|
||||
main_program = pdpd.static.Program()
|
||||
startup_program = pdpd.static.Program()
|
||||
with pdpd.static.program_guard(main_program, startup_program):
|
||||
node_x = pdpd.static.data(name='x', shape=data.shape, dtype='float32')
|
||||
node_sizes = pdpd.static.data(name='sizes', shape=sizes.shape, dtype='int32')
|
||||
interp = interpolate(node_x, size=node_sizes, scale_factor=None,
|
||||
mode='bicubic', align_corners=test['align_corners'], align_mode=test['align_mode'],
|
||||
data_format='NCHW', name=test['name'])
|
||||
out = pdpd.static.nn.batch_norm(interp, use_global_stats=True, epsilon=0)
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
exe.run(startup_program)
|
||||
outs = exe.run(
|
||||
feed={'x': data, 'sizes': sizes},
|
||||
fetch_list=out,
|
||||
program=main_program)
|
||||
saveModel(test['name'], exe, feedkeys=['x', 'sizes'], fetchlist=out, inputs=[data, sizes], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
def bicubic_upsample_scales():
|
||||
data = np.array([[[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]
|
||||
]]], dtype=np.float32)
|
||||
|
||||
test_case = [{'name': 'bicubic_upsample_scales', 'align_corners': False, 'align_mode': 1, "scales": 2},
|
||||
{'name': 'bicubic_upsample_scales2', 'align_corners': False, 'align_mode': 1, "scales": [2, 2]}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data, None, 2, mode='bicubic', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCHW', name=test['name'])
|
||||
|
||||
# linear
|
||||
def resize_upsample_linear():
|
||||
data = np.array([[
|
||||
[1, 2, 3]
|
||||
]], dtype=np.float32)
|
||||
|
||||
test_case = [{'name': 'linear_upsample_false_1', 'align_corners': False, 'align_mode': 1},
|
||||
{'name': 'linear_upsample_false_0', 'align_corners': False, 'align_mode': 0},
|
||||
{'name': 'linear_upsample_true_0', 'align_corners': True, 'align_mode': 0}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data, [6,], None, mode='linear', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCW', name=test['name'])
|
||||
|
||||
|
||||
def resize_downsample_linear():
|
||||
data = np.array([[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8]
|
||||
]], dtype=np.float32)
|
||||
data_28 = data.reshape([1, 1, 8])
|
||||
test_case = [{'name': 'linear_downsample_false_1', 'align_corners': False, 'align_mode': 1},
|
||||
{'name': 'linear_downsample_false_0', 'align_corners': False, 'align_mode': 0},
|
||||
{'name': 'linear_downsample_true_0', 'align_corners': True, 'align_mode': 0}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data_28, [4,], None, mode='linear', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCW', name=test['name'])
|
||||
|
||||
def linear_upsample_tensor_size():
|
||||
data = np.array([[
|
||||
[1, 2, 3, 4]
|
||||
]], dtype=np.float32)
|
||||
sizes = np.array([8,], dtype="int32")
|
||||
|
||||
test_case = [{'name': 'linear_upsample_tensor_size', 'align_corners': False, 'align_mode': 1}]
|
||||
|
||||
for test in test_case:
|
||||
main_program = pdpd.static.Program()
|
||||
startup_program = pdpd.static.Program()
|
||||
with pdpd.static.program_guard(main_program, startup_program):
|
||||
node_x = pdpd.static.data(name='x', shape=data.shape, dtype='float32')
|
||||
node_sizes = pdpd.static.data(name='sizes', shape=sizes.shape, dtype='int32')
|
||||
interp = interpolate(node_x, size=node_sizes, scale_factor=None,
|
||||
mode='linear', align_corners=test['align_corners'], align_mode=test['align_mode'],
|
||||
data_format='NCW', name=test['name'])
|
||||
out = pdpd.static.nn.batch_norm(interp, use_global_stats=True, epsilon=0)
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
exe.run(startup_program)
|
||||
outs = exe.run(
|
||||
feed={'x': data, 'sizes': sizes},
|
||||
fetch_list=out,
|
||||
program=main_program)
|
||||
saveModel(test['name'], exe, feedkeys=['x', 'sizes'], fetchlist=out, inputs=[data, sizes], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
def linear_upsample_scales():
|
||||
data = np.array([[
|
||||
[1, 2, 3, 4]
|
||||
]], dtype=np.float32)
|
||||
|
||||
test_case = [{'name': 'linear_upsample_scales', 'align_corners': False, 'align_mode': 1, "scales": 2},
|
||||
{'name': 'linear_upsample_scales2', 'align_corners': False, 'align_mode': 1, "scales": [2, 2]}]
|
||||
|
||||
for test in test_case:
|
||||
pdpd_result = pdpd_interpolate(data, None, 2, mode='linear', align_corners=test['align_corners'],
|
||||
align_mode=test['align_mode'], data_format='NCW', name=test['name'])
|
||||
|
||||
if __name__ == "__main__":
|
||||
# bilinear
|
||||
resize_downsample_bilinear()
|
||||
resize_upsample_bilinear()
|
||||
bilinear_upsample_tensor_size()
|
||||
bilinear_upsample_scales()
|
||||
# nearest
|
||||
resize_downsample_nearest()
|
||||
resize_upsample_nearest()
|
||||
nearest_upsample_tensor_size()
|
||||
bilinear_upsample_tensor_size()
|
||||
bilinear_upsample_scales()
|
||||
# trilinear
|
||||
resize_downsample_trilinear()
|
||||
resize_upsample_trilinear()
|
||||
trilinear_upsample_tensor_size()
|
||||
trilinear_upsample_scales()
|
||||
# bicubic
|
||||
resize_downsample_bicubic()
|
||||
resize_upsample_bicubic()
|
||||
bicubic_upsample_tensor_size()
|
||||
bicubic_upsample_scales()
|
||||
# linear
|
||||
resize_downsample_linear()
|
||||
resize_upsample_linear()
|
||||
linear_upsample_tensor_size()
|
||||
linear_upsample_scales()
|
Loading…
Reference in New Issue
Block a user