Introduce Quantize-Dequantize to FakeQuantize transformation (#1849)
* Introduce Quantize-Dequantize to FakeQuantize transformation * Revert changes in DequantizeLinear * apply code format * Changes after review: - description for transformation - remove NGRAPH_CHECK and move some checks from callback to predicates in pattern - check if out_low/high are broadcastable for FQ's first input - fix params to copy_runtime_info * Add type_matches and type_matches_any predicates * Use get_single_value * Changes after review: - add brief description of transformation - use get_pattern_value_map instead of get_pattern_map - change opset1 to opset4 - fix params to copy_runtime_info * Check result of dynamic_pointer_cast
This commit is contained in:
parent
4673dc9b9c
commit
a6076a1fd6
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
namespace pass {

// Forward declaration so the full class below can live at global scope with a qualified name.
class TRANSFORMATIONS_API ConvertQuantizeDequantize;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief ConvertQuantizeDequantize transformation replaces following graph:
 * FakeQuantize->Convert->Convert->Subtract->Multiply with a single FakeQuantize.
 * Restrictions:
 * - quantized data type must be i8 or u8
 * - 'levels' attribute to FakeQuantize must be equal to 256
 * - (output_low, output_high) must be (-128, 127) or (0, 255) (depends on sign of quantized data type)
 * - 'zero_point' and 'scale' must be broadcastable to FakeQuantize's output
 */

class ngraph::pass::ConvertQuantizeDequantize: public ngraph::pass::MatcherPass {
public:
    ConvertQuantizeDequantize();
};
|
@ -19,6 +19,7 @@
|
||||
#include "transformations/softplus_fusion.hpp"
|
||||
#include "transformations/swish_fusion.hpp"
|
||||
#include "transformations/hswish_fusion.hpp"
|
||||
#include "transformations/convert_quantize_dequantize.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
@ -33,6 +34,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
manager.register_pass<ngraph::pass::ConvertPriorBox>(); // WA: ConvertPriorBox must be executed before CF
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
|
||||
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
|
||||
manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
|
||||
|
@ -0,0 +1,153 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/convert_quantize_dequantize.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
|
||||
// ConvertQuantizeDequantize converts Quantize/Dequantize pair to a single FakeQuantize.
|
||||
// Since Quantize is decomposed to FakeQuantize and Dequantize is decomposed to Subtract->Multiply,
|
||||
// the full pattern to match is presented on the left hand side of the graph below.
|
||||
// On the right hand side is the graph after transformation.
|
||||
// Currently transformation supports only i8 and u8 quantized data type.
|
||||
// That implies 'levels' attribute to be 256, as well as (output_low, output_high) be (-128, 127) or (0, 255) (depends on sign of the quantized data type).
|
||||
// Another limitation is that 'zero_point' and 'scale' have to be broadcastable to the output of FakeQuantize.
|
||||
//
|
||||
//
|
||||
// | | | | |
|
||||
// | | | | |
|
||||
// v v v v v
|
||||
// +------------+
|
||||
// |FakeQuantize|
|
||||
// +------------+
|
||||
// |
|
||||
// v
|
||||
// +---------------------+
|
||||
// | Convert |
|
||||
// |(e.g. from f32 to u8)|
|
||||
// +---------+-----------+ | | | | |
|
||||
// | | | | | |
|
||||
// v v v v v v
|
||||
// +---------------------+ +------------+
|
||||
// | Convert | ====> |FakeQuantize|
|
||||
// | (from u8 to f32) | +------------+
|
||||
// +---------+-----------+ |
|
||||
// | v
|
||||
// v
|
||||
// +----------+ +------------+
|
||||
// |zero point|--->| Subtract |
|
||||
// +----------+ +-----+------+
|
||||
// |
|
||||
// v
|
||||
// +---------+ +------------+
|
||||
// | scale |--->| Multiply |
|
||||
// +---------+ +-----+------+
|
||||
// |
|
||||
// v
|
||||
//
|
||||
|
||||
|
||||
// Matches FakeQuantize->Convert(i8/u8)->Convert(f32)->Subtract->Multiply and folds the chain
// into a single FakeQuantize with recomputed output range (see diagram above).
ngraph::pass::ConvertQuantizeDequantize::ConvertQuantizeDequantize() {
    auto data_pattern = ngraph::pattern::any_input();
    auto input_low_pattern = ngraph::pattern::any_input();
    auto input_high_pattern = ngraph::pattern::any_input();
    // output_low/high must be constants: their values are validated in the callback below.
    auto output_low_pattern = ngraph::pattern::wrap_type<opset4::Constant>();
    auto output_high_pattern = ngraph::pattern::wrap_type<opset4::Constant>();
    auto fq_pattern = ngraph::pattern::wrap_type<opset4::FakeQuantize>({data_pattern, input_low_pattern,
                                                                        input_high_pattern, output_low_pattern,
                                                                        output_high_pattern});
    // First Convert quantizes to i8/u8, second converts back to f32 (dequantize start).
    auto convert1_pattern = ngraph::pattern::wrap_type<opset4::Convert>({fq_pattern}, pattern::type_matches_any({element::i8, element::u8}));
    auto convert2_pattern = ngraph::pattern::wrap_type<opset4::Convert>({convert1_pattern}, pattern::type_matches(element::f32));
    auto zero_point_pattern = ngraph::pattern::any_input();
    // Subtract must feed only the Multiply, otherwise removing it would break other consumers.
    auto sub_pattern = ngraph::pattern::wrap_type<opset4::Subtract>({convert2_pattern, zero_point_pattern}, pattern::consumers_count(1));
    auto scale_pattern = ngraph::pattern::any_input();
    auto mul_pattern = ngraph::pattern::wrap_type<opset4::Multiply>({sub_pattern, scale_pattern});

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_value_map();
        auto data = pattern_map[data_pattern];
        auto input_low = pattern_map[input_low_pattern];
        auto input_high = pattern_map[input_high_pattern];
        auto output_low = std::dynamic_pointer_cast<opset4::Constant>(pattern_map[output_low_pattern].get_node_shared_ptr());
        if (!output_low)
            return false;
        auto output_high = std::dynamic_pointer_cast<opset4::Constant>(pattern_map[output_high_pattern].get_node_shared_ptr());
        if (!output_high)
            return false;
        auto fq = std::dynamic_pointer_cast<opset4::FakeQuantize>(pattern_map[fq_pattern].get_node_shared_ptr());
        if (!fq)
            return false;
        auto zero_point = pattern_map[zero_point_pattern];
        auto scale = pattern_map[scale_pattern];
        auto convert1 = pattern_map[convert1_pattern];
        auto convert2 = pattern_map[convert2_pattern];
        auto mul = pattern_map[mul_pattern].get_node_shared_ptr();

        // convert1 and convert2 should have only one input
        if (convert1.get_target_inputs().size() != 1)
            return false;
        if (convert2.get_target_inputs().size() != 1)
            return false;

        // we support only i8 or u8 so 'levels' attribute must be 256
        size_t levels = fq->get_levels();
        if (levels != 256)
            return false;

        // check if (out_low_val, out_high_val) is (-128, 127) or (0, 255)
        float out_low_val;
        if (!op::util::get_single_value(output_low, out_low_val))
            return false;
        float out_high_val;
        if (!op::util::get_single_value(output_high, out_high_val))
            return false;
        // The quantized type (taken from the first Convert) dictates which range is valid.
        const auto& type = convert1.get_element_type();
        switch (type) {
        case element::Type_t::i8:
            if (out_low_val != -128 || out_high_val != 127)
                return false;
            break;
        case element::Type_t::u8:
            if (out_low_val != 0 || out_high_val != 255)
                return false;
            break;
        default:
            return false;
        }

        // New output range folds the dequantize math in: (out - zero_point) * scale.
        auto new_out_low = std::make_shared<ngraph::opset4::Multiply>(
            std::make_shared<ngraph::opset4::Subtract>(output_low, zero_point), scale);
        auto new_out_high = std::make_shared<ngraph::opset4::Multiply>(
            std::make_shared<ngraph::opset4::Subtract>(output_high, zero_point), scale);

        // check if new_out_low/high shapes are broadcastable to FQ's input
        auto data_shape = data.get_partial_shape();
        if (data_shape.rank().is_dynamic())
            return false;
        auto out_low_shape = new_out_low->get_output_partial_shape(0);
        if (out_low_shape.rank().is_dynamic() || out_low_shape.rank().get_length() > data_shape.rank().get_length())
            return false;
        auto out_high_shape = new_out_high->get_output_partial_shape(0);
        if (out_high_shape.rank().is_dynamic() || out_high_shape.rank().get_length() > data_shape.rank().get_length())
            return false;

        auto new_fq = std::make_shared<ngraph::opset4::FakeQuantize>(data, input_low, input_high, new_out_low, new_out_high, levels);
        // The new node takes over the Multiply's name since it replaces the Multiply's output.
        new_fq->set_friendly_name(mul->get_friendly_name());

        copy_runtime_info({fq, convert1.get_node_shared_ptr(), convert2.get_node_shared_ptr()}, new_fq);
        replace_node(mul, new_fq);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul_pattern, "ConvertQuantizeDequantize");
    this->register_matcher(m, callback);
}
|
@ -0,0 +1,205 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <transformations/convert_quantize_dequantize.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<Function> create_q_dq_function(const Shape& data_shape, float in_low, float in_high, float out_low, float out_high,
|
||||
const Shape& zero_point_shape, std::vector<T> zero_point_values,
|
||||
const Shape& scale_shape, std::vector<float> scale_values, size_t levels) {
|
||||
auto data = std::make_shared<opset1::Parameter>(element::f32, data_shape);
|
||||
auto input_low = opset1::Constant::create(element::f32, Shape{}, {in_low});
|
||||
auto input_high = opset1::Constant::create(element::f32, Shape{}, {in_high});
|
||||
auto output_low = opset1::Constant::create(element::f32, Shape{}, {out_low});
|
||||
auto output_high = opset1::Constant::create(element::f32, Shape{}, {out_high});
|
||||
auto fq = std::make_shared<opset1::FakeQuantize>(data, input_low,
|
||||
input_high, output_low,
|
||||
output_high, levels);
|
||||
auto convert1 = std::make_shared<opset1::Convert>(fq, element::from<T>());
|
||||
auto convert2 = std::make_shared<opset1::Convert>(convert1, element::f32);
|
||||
auto zero_point = std::make_shared<opset1::Convert>(opset1::Constant::create(element::from<T>(), zero_point_shape, zero_point_values), element::f32);
|
||||
auto sub = std::make_shared<opset1::Subtract>(convert2, zero_point);
|
||||
auto scale = opset1::Constant::create(element::f32, scale_shape, scale_values);
|
||||
auto mul = std::make_shared<opset1::Multiply>(sub, scale);
|
||||
|
||||
return std::make_shared<Function>(NodeVector{mul}, ParameterVector{data});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void positive_test(const Shape& data_shape, float in_low, float in_high, float out_low, float out_high,
|
||||
const Shape& zero_point_shape, std::vector<T> zero_point_values,
|
||||
const Shape& scale_shape, std::vector<float> scale_values, size_t levels) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
f = create_q_dq_function(data_shape, in_low, in_high, out_low, out_high,
|
||||
zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::ConvertQuantizeDequantize>();
|
||||
m.register_pass<pass::ConstantFolding>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<opset1::Parameter>(element::f32, data_shape);
|
||||
auto input_low = opset1::Constant::create(element::f32, Shape{}, {in_low});
|
||||
auto input_high = opset1::Constant::create(element::f32, Shape{}, {in_high});
|
||||
auto output_low = opset1::Constant::create(element::f32, Shape{}, {(out_low - zero_point_values[0]) * scale_values[0]});
|
||||
auto output_high = opset1::Constant::create(element::f32, Shape{}, {(out_high - zero_point_values[0]) * scale_values[0]});
|
||||
auto fq = std::make_shared<opset1::FakeQuantize>(data, input_low,
|
||||
input_high, output_low,
|
||||
output_high, levels);
|
||||
f_ref = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertQuantizeDequantizeINT8) {
    // i8 quantization: levels == 256 and output range (-128, 127), so the
    // transformation is expected to fire. (Removed unused f/f_ref locals.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{};
    std::vector<int8_t> zero_point_values{2};
    Shape scale_shape{};
    std::vector<float> scale_values{3};
    size_t levels = 256;

    positive_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
||||
|
||||
TEST(TransformationTests, ConvertQuantizeDequantizeUINT8) {
    // u8 quantization: levels == 256 and output range (0, 255), so the
    // transformation is expected to fire. (Removed unused f/f_ref locals.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = 0;
    float out_high = 255;
    Shape zero_point_shape{};
    std::vector<uint8_t> zero_point_values{2};
    Shape scale_shape{};
    std::vector<float> scale_values{3};
    size_t levels = 256;

    positive_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
|
||||
|
||||
template <typename T>
|
||||
void negative_test(const Shape& data_shape, float in_low, float in_high, float out_low, float out_high,
|
||||
const Shape& zero_point_shape, std::vector<T> zero_point_values,
|
||||
const Shape& scale_shape, std::vector<float> scale_values, size_t levels) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
f = create_q_dq_function(data_shape, in_low, in_high, out_low, out_high,
|
||||
zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::ConvertQuantizeDequantize>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
// negative test so the transformation does not fire and reference is the same graph as original
|
||||
f_ref = create_q_dq_function(data_shape, in_low, in_high, out_low, out_high,
|
||||
zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
|
||||
}
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertQuantizeDequantizeZeroPointNotBroadcastable) {
    // zero_point rank (4) exceeds the data rank (3), so the recomputed output range
    // would not broadcast to FQ's input -> transformation must not fire.
    // (Removed unused f/f_ref locals.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{1, 1, 1, 1};
    std::vector<int8_t> zero_point_values{2};
    Shape scale_shape{1};
    std::vector<float> scale_values{3};
    size_t levels = 256;

    negative_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
|
||||
|
||||
TEST(TransformationTests, ConvertQuantizeDequantizeScaleNotBroadcastable) {
    // scale rank (4) exceeds the data rank (3), so the recomputed output range
    // would not broadcast to FQ's input -> transformation must not fire.
    // (Removed unused f/f_ref locals.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{};
    std::vector<int8_t> zero_point_values{2};
    Shape scale_shape{1, 1, 1, 1};
    std::vector<float> scale_values{3};
    size_t levels = 256;

    negative_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
|
||||
|
||||
TEST(TransformationTests, ConvertQuantizeDequantizeInvalidLevels) {
    // levels != 256 is unsupported (only i8/u8 quantization is handled),
    // so the transformation must not fire. (Removed unused f/f_ref locals.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{};
    std::vector<int8_t> zero_point_values{2};
    Shape scale_shape{};
    std::vector<float> scale_values{3};
    size_t levels = 127;

    negative_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
|
||||
|
||||
TEST(TransformationTests, ConvertQuantizeDequantizeInvalidOutLowOutHigh) {
    // (-128, 127) is invalid for the uint8_t data type (u8 requires (0, 255)),
    // so the transformation must not fire. (Removed unused f/f_ref locals;
    // fixed "uin8_t" typo.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{};
    std::vector<uint8_t> zero_point_values{2};
    Shape scale_shape{};
    std::vector<float> scale_values{3};
    size_t levels = 256;

    negative_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
|
@ -61,6 +61,12 @@ namespace ngraph
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> has_static_shape();
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
|
||||
|
||||
namespace op
|
||||
{
|
||||
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
|
||||
|
@ -94,5 +94,21 @@ namespace ngraph
|
||||
return
|
||||
[=](Output<Node> output) -> bool { return output.get_partial_shape().is_static(); };
|
||||
}
|
||||
|
||||
std::function<bool(Output<Node>)> type_matches(const element::Type& type)
|
||||
{
|
||||
return [=](Output<Node> output) -> bool { return output.get_element_type() == type; };
|
||||
}
|
||||
|
||||
std::function<bool(Output<Node>)>
|
||||
type_matches_any(const std::vector<element::Type>& expected_types)
|
||||
{
|
||||
return [=](Output<Node> output) -> bool {
|
||||
const auto& output_type = output.get_element_type();
|
||||
return std::any_of(expected_types.begin(),
|
||||
expected_types.end(),
|
||||
[=](element::Type type) { return type == output_type; });
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
108
ngraph/test/models/onnx/quant_dequant_pattern.prototxt
Normal file
108
ngraph/test/models/onnx/quant_dequant_pattern.prototxt
Normal file
@ -0,0 +1,108 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
input: "scale"
|
||||
input: "zero_point"
|
||||
output: "quantization_out"
|
||||
name: "quantization"
|
||||
op_type: "QuantizeLinear"
|
||||
}
|
||||
node {
|
||||
input: "quantization_out"
|
||||
input: "scale"
|
||||
input: "zero_point"
|
||||
output: "dequantization_out"
|
||||
name: "dequantization"
|
||||
op_type: "DequantizeLinear"
|
||||
}
|
||||
node {
|
||||
input: "dequantization_out"
|
||||
input: "x"
|
||||
output: "mul_out"
|
||||
name: "mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
|
||||
name: "test_graph"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "scale"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "zero_point"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 3
|
||||
shape {
|
||||
dim {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output {
|
||||
name: "mul_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 0
|
||||
data_type: 1
|
||||
name: "scale"
|
||||
float_data: 3
|
||||
}
|
||||
initializer {
|
||||
dims: 0
|
||||
data_type: 2
|
||||
name: "zero_point"
|
||||
int32_data: 10
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
125
ngraph/test/models/onnx/quant_dequant_pattern_axis.prototxt
Normal file
125
ngraph/test/models/onnx/quant_dequant_pattern_axis.prototxt
Normal file
@ -0,0 +1,125 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
input: "scale"
|
||||
input: "zero_point"
|
||||
output: "quantization_out"
|
||||
name: "quantization"
|
||||
op_type: "QuantizeLinear"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "quantization_out"
|
||||
input: "scale"
|
||||
input: "zero_point"
|
||||
output: "dequantization_out"
|
||||
name: "dequantization"
|
||||
op_type: "DequantizeLinear"
|
||||
}
|
||||
node {
|
||||
input: "dequantization_out"
|
||||
input: "x"
|
||||
output: "mul_out"
|
||||
name: "mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
|
||||
name: "test_graph"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "scale"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "zero_point"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 3
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output {
|
||||
name: "mul_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 3
|
||||
data_type: 1
|
||||
name: "scale"
|
||||
float_data: 2
|
||||
float_data: 3
|
||||
float_data: 4
|
||||
}
|
||||
initializer {
|
||||
dims: 3
|
||||
data_type: 2
|
||||
name: "zero_point"
|
||||
int32_data: 10
|
||||
int32_data: 20
|
||||
int32_data: 30
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -2447,3 +2447,30 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_empty_initializers_handling)
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
// Round-trips data through ONNX QuantizeLinear->DequantizeLinear (scalar scale/zero_point)
// followed by Mul, and checks the end-to-end values against the quantization formula.
NGRAPH_TEST(${BACKEND_NAME}, quant_dequant_pattern)
{
    const auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/quant_dequant_pattern.prototxt"));
    auto test_case = test::TestCase<TestEngine>(function);
    // scale == 3.0
    // zero point == 10
    test_case.add_input<float>({9.0, 10.0, 15.0, 20.0, 30.0});
    // second input is "x", the Mul multiplier (1 -> output equals the dequantized values)
    test_case.add_input<float>({1});
    test_case.add_expected_output<float>(Shape{5}, {9.0, 9.0, 15.0, 21.0, 30.0});
    test_case.run();
}
|
||||
|
||||
// Same as quant_dequant_pattern but with per-channel (axis=1) scale and zero_point.
NGRAPH_TEST(${BACKEND_NAME}, quant_dequant_pattern_axis)
{
    const auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/quant_dequant_pattern_axis.prototxt"));
    auto test_case = test::TestCase<TestEngine>(function);
    // axis = 1
    // scale == {2.0, 3.0, 4.0}
    // zero point == {10, 20, 30}
    test_case.add_input<float>({1.0, 2.0, 3.0, 10.0, 20.0, 30.0, 40.0, 50.0, 100.0});
    test_case.add_expected_output<float>(Shape{3, 3}, {0, 3, 4, 10, 21, 32, 40, 51, 100});
    // NOTE(review): this add_input comes after add_expected_output, unlike the scalar test
    // above; presumably TestCase keeps inputs and outputs in separate lists so order between
    // them doesn't matter — confirm, and consider reordering for consistency.
    test_case.add_input<float>({1});
    test_case.run();
}
|
||||
|
Loading…
Reference in New Issue
Block a user