[PT FE] Support chunk->ListUnpack in quantized models (#18823)

* [PT FE] Support chunk->ListUnpack in quantized models

* Add comment

* Use type

* Fix build

* Update src/frontends/pytorch/src/op/list_unpack.cpp

* Update src/frontends/pytorch/src/utils.hpp
This commit is contained in:
Maxim Vafin 2023-07-27 21:10:21 +02:00 committed by GitHub
parent 1857c7a793
commit ab8b46165b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 171 additions and 95 deletions

View File

@ -25,7 +25,6 @@
#include "transforms/aten_index_put_replacer.hpp"
#include "transforms/aten_index_replacer.hpp"
#include "transforms/aten_stack_list_construct_replacer.hpp"
#include "transforms/dequantize_node_remover.hpp"
#include "transforms/dict_resolver.hpp"
#include "transforms/einsum_list_construct.hpp"
#include "transforms/index_loop_getitem_replacer.hpp"
@ -192,7 +191,6 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DequantizeNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "utils.hpp"
#include "utils_quantize.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_list_unpack(const NodeContext& context) {
    // prim::ListUnpack cannot be translated generically: lists are resolved only in
    // specific recognized patterns. Emit a framework node whose outputs mirror the
    // unpacked list elements; later transformations may resolve or replace it.
    const auto& outputs =
        make_framework_node(context, "Lists are not supported yet and can be resolved only in specific cases.");
    const auto& input = context.get_input(0);
    const auto& input_node = input.get_node_shared_ptr();
    // Guard: a producer with no inputs (e.g. an empty prim::ListConstruct) has no
    // input_value(0) — querying it would throw. Fall through to the generic path.
    if (input_node->get_input_size() == 0) {
        return outputs;
    }
    const auto& quantized_node = input_node->input_value(0);
    if (const auto& quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
        // Quantized case: only aten::chunk feeding the ListUnpack is supported.
        FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input_node, "aten::chunk"), "Unsupported operation type.");
        // Re-attach the quantization metadata (type, scale, zero point, dtype) to every
        // unpacked output so quantization information survives the unpack.
        OutputVector res;
        res.reserve(outputs.size());
        for (const auto& output : outputs) {
            res.push_back(context.mark_node(std::make_shared<QuantizedPtNode>(quantized_pt_node->get_type(),
                                                                              output,
                                                                              quantized_pt_node->get_scale(),
                                                                              quantized_pt_node->get_zero_point(),
                                                                              quantized_pt_node->get_dtype())));
        }
        return res;
    }
    // Non-quantized input: pass the framework node outputs through unchanged.
    return outputs;
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -85,6 +85,7 @@ OP_CONVERTER(translate_linalg_matrix_norm);
OP_CONVERTER(translate_linalg_vector_norm);
OP_CONVERTER(translate_linear);
OP_CONVERTER(translate_list_construct);
OP_CONVERTER(translate_list_unpack);
OP_CONVERTER(translate_log);
OP_CONVERTER(translate_log_softmax);
OP_CONVERTER(translate_log2);
@ -265,6 +266,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::cosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cosh>>},
{"aten::cumsum", op::translate_cumsum},
{"aten::detach", op::skip_node},
{"aten::dequantize", op::skip_node}, // we convert model to fp32 using FQ, so dequantization is not needed
{"aten::dim", op::translate_dim},
{"aten::div", op::translate_div},
{"aten::div_", op::inplace_op<op::translate_div>},
@ -449,6 +451,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"prim::If", op::translate_if},
{"prim::is_cuda", op::return_false_scalar},
{"prim::ListConstruct", op::translate_list_construct},
{"prim::ListUnpack", op::translate_list_unpack},
{"prim::Loop", op::translate_loop},
{"prim::NumToTensor", op::skip_node}, // In openvino we already store number as tensor with shape []
{"prim::requires_grad", op::return_false_scalar},

View File

@ -1,42 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "dequantize_node_remover.hpp"
#include <memory>
#include <utility>
#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
#include "utils_quantize.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
// Finds unconverted aten::dequantize framework nodes and bypasses them by re-wiring
// their consumers directly to the dequantize input. This matches the behavior of
// OV's LPT, which also drops dequantization subgraphs.
DequantizeNodeRemover::DequantizeNodeRemover() {
    // Match any PtFrameworkNode; the callback filters for aten::dequantize.
    // (Renamed from the callback-local variable to avoid shadowing.)
    const auto pattern_node = ov::pass::pattern::wrap_type<ov::frontend::pytorch::PtFrameworkNode>();

    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
        auto dequantize_node = cast_fw_node(m.get_match_root(), "aten::dequantize");
        if (!dequantize_node)
            return false;
        // Short-circuit the node: consumers now read its input directly.
        dequantize_node->output(0).replace(dequantize_node->input_value(0));
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_node,
                                                          "ov::frontend::pytorch::pass::DequantizeNodeRemover");
    this->register_matcher(m, callback);
}
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -1,29 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
/**
* Dequantize Node Remover
* Replacer finds the unconverted dequantize ops and removes them.
* This matches the behavior of OV's LPT.
*/
class DequantizeNodeRemover : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::DequantizeNodeRemover");
DequantizeNodeRemover();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -142,6 +142,89 @@ inline OutputVector skip_node(const NodeContext& context) {
}
} // namespace op
/// Decoder stub for PtFrameworkNodes synthesized by the frontend itself rather than
/// decoded from a real TorchScript node. Every TorchDecoder entry point aborts with
/// FRONT_END_NOT_IMPLEMENTED; concrete subclasses override only the queries that are
/// actually made for their synthetic node kind.
class DummyDecoder : public TorchDecoder {
public:
    // --- Input queries: not available for synthetic nodes ---
    virtual Any const_input(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(const_input);
    }
    virtual const std::vector<size_t>& inputs() const override {
        FRONT_END_NOT_IMPLEMENTED(inputs);
    }
    virtual const std::string& get_input_debug_name(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_input_debug_name);
    }
    virtual const std::string& get_input_signature_name(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_input_signature_name);
    }
    virtual PartialShape get_input_shape(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_input_shape);
    }
    virtual Any get_input_type(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_input_type);
    }
    virtual const std::vector<size_t>& get_input_transpose_order(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_input_transpose_order);
    }
    // --- Output queries ---
    virtual const std::string& get_output_debug_name(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_output_debug_name);
    }
    virtual PartialShape get_output_shape(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_output_shape);
    }
    virtual Any get_output_type(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_output_type);
    }
    virtual const std::vector<size_t>& get_output_transpose_order(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_output_transpose_order);
    }
    virtual bool input_is_none(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(input_is_none);
    }
    // --- Node-level queries ---
    virtual OutputVector try_decode_get_attr() const override {
        FRONT_END_NOT_IMPLEMENTED(try_decode_get_attr);
    }
    virtual OutputVector as_constant() const override {
        FRONT_END_NOT_IMPLEMENTED(as_constant);
    }
    virtual const std::string& as_string() const override {
        FRONT_END_NOT_IMPLEMENTED(as_string);
    }
    virtual const std::string& get_op_type() const override {
        FRONT_END_NOT_IMPLEMENTED(get_op_type);
    }
    virtual const std::string& get_schema() const override {
        FRONT_END_NOT_IMPLEMENTED(get_schema);
    }
    virtual size_t num_of_outputs() const override {
        FRONT_END_NOT_IMPLEMENTED(num_of_outputs);
    }
    virtual const std::vector<size_t>& outputs() const override {
        FRONT_END_NOT_IMPLEMENTED(outputs);
    }
    virtual size_t output(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(output);
    }
    virtual std::shared_ptr<Node> mark_node(std::shared_ptr<Node> ov_node) const override {
        FRONT_END_NOT_IMPLEMENTED(mark_node);
    }
    // --- Subgraph queries: synthetic nodes carry no subgraphs ---
    virtual size_t get_subgraph_size() const override {
        FRONT_END_NOT_IMPLEMENTED(get_subgraph_size);
    }
    virtual void visit_subgraph(std::function<void(std::shared_ptr<TorchDecoder>)> node_visitor) const override {
        FRONT_END_NOT_IMPLEMENTED(visit_subgraph);
    }
    virtual std::shared_ptr<TorchDecoder> get_subgraph_decoder(size_t index) const override {
        FRONT_END_NOT_IMPLEMENTED(get_subgraph_decoder);
    }
    virtual bool may_produce_alias(size_t in_index, size_t out_index) const override {
        FRONT_END_NOT_IMPLEMENTED(may_produce_alias);
    }
    virtual OutputVector inlined_inputs(size_t start_index) const override {
        FRONT_END_NOT_IMPLEMENTED(inlined_inputs);
    }
};
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -57,7 +57,6 @@ Output<Node> quantize(const NodeContext& context,
std::make_shared<v0::FakeQuantize>(input_convert, bound_low, bound_high, bound_low, bound_high, levels));
return context.mark_node(std::make_shared<QuantizedPtNode>(quantization_type,
context,
quantized_input,
scale_convert,
zero_point_convert,
@ -109,7 +108,6 @@ Output<Node> quantize(const NodeContext& context,
std::make_shared<v0::FakeQuantize>(input_convert, bound_low, bound_high, bound_low, bound_high, levels));
return context.mark_node(std::make_shared<QuantizedPtNode>(quantization_type,
context,
quantized_input,
scale_bc,
zero_point_bc,

View File

@ -11,6 +11,31 @@ namespace ov {
namespace frontend {
namespace pytorch {
/// Decoder backing QuantizedPtNode. Unlike DummyDecoder it answers the few queries
/// that are made for synthetic quantization nodes: output shape (mirrors the
/// quantized input), op type, schema, output count and subgraph count. All other
/// TorchDecoder methods inherit the FRONT_END_NOT_IMPLEMENTED stubs.
class QuantizedDecoder : public DummyDecoder {
public:
    // explicit: a single-argument constructor must not act as an implicit conversion.
    explicit QuantizedDecoder(const Output<Node>& input) : m_qinput(input) {}
    /// The single output has the same (partial) shape as the quantized input.
    virtual PartialShape get_output_shape(size_t index) const override {
        return m_qinput.get_partial_shape();
    }
    virtual const std::string& get_op_type() const override {
        return m_op_type;
    }
    virtual const std::string& get_schema() const override {
        return m_schema;
    }
    virtual size_t num_of_outputs() const override {
        return 1;
    }
    virtual size_t get_subgraph_size() const override {
        return 0;
    }

private:
    // Stored only to report the output shape of the quantized input.
    const Output<Node> m_qinput;
    const std::string m_op_type = "QuantizedPtNode";
    const std::string m_schema = "NONE";
};
// Quantization scheme carried by a QuantizedPtNode: a single scale/zero-point for the
// whole tensor, or per-channel scale/zero-point vectors along an axis.
enum QuantizedPtNodeType { QUANTIZE_PER_TENSOR, QUANTIZE_PER_CHANNEL };
class QuantizedPtNode : public PtFrameworkNode {
@ -21,12 +46,11 @@ public:
static constexpr const char* quantize_per_channel = "quantize_per_channel";
QuantizedPtNode(const QuantizedPtNodeType type,
const NodeContext& context,
const Output<Node> input,
const Output<Node> scale,
const Output<Node> zero_point,
element::Type& dtype)
: PtFrameworkNode(context.get_decoder(), {input, scale, zero_point}, 1, false),
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const element::Type& dtype)
: PtFrameworkNode(std::make_shared<QuantizedDecoder>(input), {input, scale, zero_point}, 1, false),
type(type) {
ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
@ -41,13 +65,12 @@ public:
}
QuantizedPtNode(const QuantizedPtNodeType type,
const NodeContext& context,
const Output<Node> input,
const Output<Node> scale,
const Output<Node> zero_point,
const Output<Node> axis,
element::Type& dtype)
: PtFrameworkNode(context.get_decoder(), {input, scale, zero_point, axis}, 1, false),
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& axis,
const element::Type& dtype)
: PtFrameworkNode(std::make_shared<QuantizedDecoder>(input), {input, scale, zero_point, axis}, 1, false),
type(type) {
ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
@ -129,13 +152,11 @@ OutputVector quantizable_op(const NodeContext& context) {
auto translation_res = T(context);
FRONT_END_OP_CONVERSION_CHECK(translation_res.size() > out_idx, "Not enough outputs to apply quantization.");
if (const auto quantized_pt_node = cast_quantized_fw_node(context.get_input(in_idx).get_node_shared_ptr())) {
return {quantize(context,
translation_res[out_idx],
quantized_pt_node->get_scale(),
quantized_pt_node->get_zero_point(),
quantized_pt_node->get_axis(),
quantized_pt_node->get_dtype(),
quantized_pt_node->get_type())};
return {context.mark_node(std::make_shared<QuantizedPtNode>(quantized_pt_node->get_type(),
translation_res[out_idx],
quantized_pt_node->get_scale(),
quantized_pt_node->get_zero_point(),
quantized_pt_node->get_dtype()))};
}
return translation_res;
}