Fix handling of the QuantizeLinear op connected to multiout ops (#4741)

This commit is contained in:
Tomasz Dołbniak 2021-03-12 11:01:37 +01:00 committed by GitHub
parent 39ae56d71f
commit 6ac849eed3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -39,11 +39,11 @@ namespace ngraph
{
namespace
{
std::shared_ptr<ngraph::Node> get_zero_point(const OutputVector& inputs)
Output<ngraph::Node> get_zero_point(const OutputVector& inputs)
{
if (inputs.size() > 2)
{
return inputs.at(2).get_node_shared_ptr();
return inputs.at(2);
}
else
{
@ -53,9 +53,9 @@ namespace ngraph
}
void validate_zero_point_type(const Node& onnx_node,
const std::shared_ptr<ngraph::Node>& y_zero_point)
const Output<ngraph::Node>& y_zero_point)
{
const auto& y_zero_point_et = y_zero_point->get_element_type();
const auto& y_zero_point_et = y_zero_point.get_element_type();
CHECK_VALID_NODE(
onnx_node,
y_zero_point_et.is_static() &&
@ -64,11 +64,10 @@ namespace ngraph
"integer type.");
}
std::shared_ptr<ngraph::Node>
validate_scale(const Node& onnx_node,
const std::shared_ptr<ngraph::Node>& y_scale)
Output<ngraph::Node> validate_scale(const Node& onnx_node,
const Output<ngraph::Node>& y_scale)
{
const auto& y_scale_et = y_scale->get_element_type();
const auto& y_scale_et = y_scale.get_element_type();
CHECK_VALID_NODE(onnx_node,
y_scale_et.is_static(),
"\"y_scale\" input data type must be static.");
@ -79,10 +78,10 @@ namespace ngraph
return y_scale;
}
std::shared_ptr<ngraph::Node> validate_data(const Node& onnx_node,
std::shared_ptr<ngraph::Node>& data)
Output<ngraph::Node> validate_data(const Node& onnx_node,
const Output<ngraph::Node>& data)
{
const auto& data_et = data->get_element_type();
const auto& data_et = data.get_element_type();
CHECK_VALID_NODE(onnx_node,
data_et.is_static(),
"\"x\" input data type must be static.");
@ -120,8 +119,8 @@ namespace ngraph
}
std::tuple<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>
get_input_bands(const std::shared_ptr<ngraph::Node>& y_scale,
const std::shared_ptr<ngraph::Node>& y_zero_point,
get_input_bands(const Output<ngraph::Node>& y_scale,
const Output<ngraph::Node>& y_zero_point,
const std::shared_ptr<ngraph::Node>& output_low,
const std::shared_ptr<ngraph::Node>& output_high,
const element::Type& data_type)
@ -142,12 +141,12 @@ namespace ngraph
}
std::shared_ptr<ngraph::Node>
make_fake_quantize(const std::shared_ptr<ngraph::Node>& y_scale,
const std::shared_ptr<ngraph::Node>& y_zero_point,
const std::shared_ptr<ngraph::Node>& data)
make_fake_quantize(const Output<ngraph::Node>& y_scale,
const Output<ngraph::Node>& y_zero_point,
const Output<ngraph::Node>& data)
{
const element::Type& destination_type = y_zero_point->get_element_type();
const element::Type& data_type = data->get_element_type();
const element::Type& destination_type = y_zero_point.get_element_type();
const element::Type& data_type = data.get_element_type();
std::shared_ptr<ngraph::Node> output_low;
std::shared_ptr<ngraph::Node> output_high;
@ -166,16 +165,16 @@ namespace ngraph
data, input_low, input_high, output_low, output_high, levels),
destination_type);
}
}
}
} // namespace
} // namespace detail
namespace set_1
{
OutputVector quantize_linear(const Node& node)
{
OutputVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0).get_node_shared_ptr();
auto y_scale = inputs.at(1).get_node_shared_ptr();
auto x = inputs.at(0);
auto y_scale = inputs.at(1);
auto y_zero_point = detail::get_zero_point(inputs);
x = detail::validate_data(node, x);
@ -191,21 +190,21 @@ namespace ngraph
OutputVector quantize_linear(const Node& node)
{
OutputVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0).get_node_shared_ptr();
auto y_scale = inputs.at(1).get_node_shared_ptr();
auto x = inputs.at(0);
auto y_scale = inputs.at(1);
auto y_zero_point = detail::get_zero_point(inputs);
x = detail::validate_data(node, x);
detail::validate_zero_point_type(node, y_zero_point);
y_scale = detail::validate_scale(node, y_scale);
const auto& x_shape = x->get_output_partial_shape(0);
const auto& x_shape = x.get_partial_shape();
int64_t axis{node.get_attribute_value<int64_t>("axis", 1)};
axis = normalize_axis(node.get_description(), axis, x_shape.rank());
const auto& y_scale_shape = y_scale->get_output_partial_shape(0);
const auto& y_zero_point_shape = y_zero_point->get_output_partial_shape(0);
const auto& y_scale_shape = y_scale.get_partial_shape();
const auto& y_zero_point_shape = y_zero_point.get_partial_shape();
if (y_scale_shape.rank().is_static() &&
y_scale_shape.rank().get_length() == 1 && x_shape.rank().is_static() &&