Fixed performance drop for Interpolate-4 (#4354)
* Commit. * Reverted fix in the nGraph conversion of Intepolate-1 into Interpolate-4. * Small fix. * Added comment. * Added TODO.
This commit is contained in:
parent
24aeb16fd1
commit
3b2506989e
@ -36,6 +36,7 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat
|
|||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto input_shape_rank = inp_partial_shape.rank().get_length();
|
||||||
auto scalesConstant = ngraph::op::Constant::create(ngraph::element::f32, {scales.size()}, scales);
|
auto scalesConstant = ngraph::op::Constant::create(ngraph::element::f32, {scales.size()}, scales);
|
||||||
auto axisConstant = ngraph::op::Constant::create(ngraph::element::i64, {attrsV0.axes.size()},
|
auto axisConstant = ngraph::op::Constant::create(ngraph::element::i64, {attrsV0.axes.size()},
|
||||||
std::vector<std::size_t>{attrsV0.axes.begin(), attrsV0.axes.end()});
|
std::vector<std::size_t>{attrsV0.axes.begin(), attrsV0.axes.end()});
|
||||||
@ -45,7 +46,20 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat
|
|||||||
if (attrsV0.mode == "nearest") {
|
if (attrsV0.mode == "nearest") {
|
||||||
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::nearest;
|
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::nearest;
|
||||||
} else if (attrsV0.mode == "linear") {
|
} else if (attrsV0.mode == "linear") {
|
||||||
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::linear;
|
// If we write only
|
||||||
|
// attrsV4.mode = ngraph::op::v4::Interpolate::InterpolateMode::linear;
|
||||||
|
// instead of a conditional statements below when attrsV0.mode == "linear",
|
||||||
|
// then we have a performance drop, because CPU and GPU have no optimized
|
||||||
|
// version of the 'linear' mode.
|
||||||
|
// TODO: delete this conditional statement, when CPU and GPU will have
|
||||||
|
// optimized version of the 'linear' mode.
|
||||||
|
if (input_shape_rank < 5) {
|
||||||
|
attrsV4.mode = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx;
|
||||||
|
} else if (input_shape_rank == 5) {
|
||||||
|
attrsV4.mode = ngraph::op::v4::Interpolate::InterpolateMode::linear;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
} else if (attrsV0.mode == "cubic") {
|
} else if (attrsV0.mode == "cubic") {
|
||||||
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::cubic;
|
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::cubic;
|
||||||
} else if (attrsV0.mode == "linear_onnx") {
|
} else if (attrsV0.mode == "linear_onnx") {
|
||||||
|
Loading…
Reference in New Issue
Block a user