Fix handling of the QuantizeLinear op connected to multiout ops (#4741)
This commit is contained in:
parent
39ae56d71f
commit
6ac849eed3
@ -39,11 +39,11 @@ namespace ngraph
|
||||
{
|
||||
namespace
|
||||
{
|
||||
std::shared_ptr<ngraph::Node> get_zero_point(const OutputVector& inputs)
|
||||
Output<ngraph::Node> get_zero_point(const OutputVector& inputs)
|
||||
{
|
||||
if (inputs.size() > 2)
|
||||
{
|
||||
return inputs.at(2).get_node_shared_ptr();
|
||||
return inputs.at(2);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -53,9 +53,9 @@ namespace ngraph
|
||||
}
|
||||
|
||||
void validate_zero_point_type(const Node& onnx_node,
|
||||
const std::shared_ptr<ngraph::Node>& y_zero_point)
|
||||
const Output<ngraph::Node>& y_zero_point)
|
||||
{
|
||||
const auto& y_zero_point_et = y_zero_point->get_element_type();
|
||||
const auto& y_zero_point_et = y_zero_point.get_element_type();
|
||||
CHECK_VALID_NODE(
|
||||
onnx_node,
|
||||
y_zero_point_et.is_static() &&
|
||||
@ -64,11 +64,10 @@ namespace ngraph
|
||||
"integer type.");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node>
|
||||
validate_scale(const Node& onnx_node,
|
||||
const std::shared_ptr<ngraph::Node>& y_scale)
|
||||
Output<ngraph::Node> validate_scale(const Node& onnx_node,
|
||||
const Output<ngraph::Node>& y_scale)
|
||||
{
|
||||
const auto& y_scale_et = y_scale->get_element_type();
|
||||
const auto& y_scale_et = y_scale.get_element_type();
|
||||
CHECK_VALID_NODE(onnx_node,
|
||||
y_scale_et.is_static(),
|
||||
"\"y_scale\" input data type must be static.");
|
||||
@ -79,10 +78,10 @@ namespace ngraph
|
||||
return y_scale;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> validate_data(const Node& onnx_node,
|
||||
std::shared_ptr<ngraph::Node>& data)
|
||||
Output<ngraph::Node> validate_data(const Node& onnx_node,
|
||||
const Output<ngraph::Node>& data)
|
||||
{
|
||||
const auto& data_et = data->get_element_type();
|
||||
const auto& data_et = data.get_element_type();
|
||||
CHECK_VALID_NODE(onnx_node,
|
||||
data_et.is_static(),
|
||||
"\"x\" input data type must be static.");
|
||||
@ -120,8 +119,8 @@ namespace ngraph
|
||||
}
|
||||
|
||||
std::tuple<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>
|
||||
get_input_bands(const std::shared_ptr<ngraph::Node>& y_scale,
|
||||
const std::shared_ptr<ngraph::Node>& y_zero_point,
|
||||
get_input_bands(const Output<ngraph::Node>& y_scale,
|
||||
const Output<ngraph::Node>& y_zero_point,
|
||||
const std::shared_ptr<ngraph::Node>& output_low,
|
||||
const std::shared_ptr<ngraph::Node>& output_high,
|
||||
const element::Type& data_type)
|
||||
@ -142,12 +141,12 @@ namespace ngraph
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node>
|
||||
make_fake_quantize(const std::shared_ptr<ngraph::Node>& y_scale,
|
||||
const std::shared_ptr<ngraph::Node>& y_zero_point,
|
||||
const std::shared_ptr<ngraph::Node>& data)
|
||||
make_fake_quantize(const Output<ngraph::Node>& y_scale,
|
||||
const Output<ngraph::Node>& y_zero_point,
|
||||
const Output<ngraph::Node>& data)
|
||||
{
|
||||
const element::Type& destination_type = y_zero_point->get_element_type();
|
||||
const element::Type& data_type = data->get_element_type();
|
||||
const element::Type& destination_type = y_zero_point.get_element_type();
|
||||
const element::Type& data_type = data.get_element_type();
|
||||
|
||||
std::shared_ptr<ngraph::Node> output_low;
|
||||
std::shared_ptr<ngraph::Node> output_high;
|
||||
@ -166,16 +165,16 @@ namespace ngraph
|
||||
data, input_low, input_high, output_low, output_high, levels),
|
||||
destination_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace detail
|
||||
|
||||
namespace set_1
|
||||
{
|
||||
OutputVector quantize_linear(const Node& node)
|
||||
{
|
||||
OutputVector inputs{node.get_ng_inputs()};
|
||||
auto x = inputs.at(0).get_node_shared_ptr();
|
||||
auto y_scale = inputs.at(1).get_node_shared_ptr();
|
||||
auto x = inputs.at(0);
|
||||
auto y_scale = inputs.at(1);
|
||||
auto y_zero_point = detail::get_zero_point(inputs);
|
||||
|
||||
x = detail::validate_data(node, x);
|
||||
@ -191,21 +190,21 @@ namespace ngraph
|
||||
OutputVector quantize_linear(const Node& node)
|
||||
{
|
||||
OutputVector inputs{node.get_ng_inputs()};
|
||||
auto x = inputs.at(0).get_node_shared_ptr();
|
||||
auto y_scale = inputs.at(1).get_node_shared_ptr();
|
||||
auto x = inputs.at(0);
|
||||
auto y_scale = inputs.at(1);
|
||||
auto y_zero_point = detail::get_zero_point(inputs);
|
||||
|
||||
x = detail::validate_data(node, x);
|
||||
detail::validate_zero_point_type(node, y_zero_point);
|
||||
y_scale = detail::validate_scale(node, y_scale);
|
||||
|
||||
const auto& x_shape = x->get_output_partial_shape(0);
|
||||
const auto& x_shape = x.get_partial_shape();
|
||||
|
||||
int64_t axis{node.get_attribute_value<int64_t>("axis", 1)};
|
||||
axis = normalize_axis(node.get_description(), axis, x_shape.rank());
|
||||
|
||||
const auto& y_scale_shape = y_scale->get_output_partial_shape(0);
|
||||
const auto& y_zero_point_shape = y_zero_point->get_output_partial_shape(0);
|
||||
const auto& y_scale_shape = y_scale.get_partial_shape();
|
||||
const auto& y_zero_point_shape = y_zero_point.get_partial_shape();
|
||||
|
||||
if (y_scale_shape.rank().is_static() &&
|
||||
y_scale_shape.rank().get_length() == 1 && x_shape.rank().is_static() &&
|
||||
|
Loading…
Reference in New Issue
Block a user