[LPT] Security fixes (#10381)

This commit is contained in:
Vladislav Golubev 2022-02-16 10:31:17 +03:00 committed by GitHub
parent cbb5dff9c1
commit fa4246d531
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 22 additions and 13 deletions

View File

@ -317,7 +317,7 @@ std::shared_ptr<Node> foldConvert(const Output<Node>& node, const element::Type
template <typename T, typename... Args>
std::shared_ptr<Node> fold_reshape(Args&&... args) {
std::shared_ptr<Node> node = std::make_shared<T>(std::forward<Args>(args)...);
std::shared_ptr<Node> node = std::make_shared<T>(args...);
if (node->get_output_size() == 1) {
// issue #57985: remove fold_reshape & reuse nGraph implementation
const auto values = ov::as_type_ptr<opset1::Constant>(node->input_value(1).get_node_shared_ptr())->template cast_vector<int64_t>();
@ -325,7 +325,6 @@ std::shared_ptr<Node> fold_reshape(Args&&... args) {
return fold<opset1::Reshape>(std::forward<Args>(args)...);
}
OutputVector folded;
if (ov::is_type<opset1::Constant>(node->input_value(0).get_node_shared_ptr()) &&
ov::is_type<opset1::Constant>(node->input_value(1).get_node_shared_ptr())) {
return std::make_shared<opset1::Constant>(

View File

@ -34,8 +34,8 @@ FakeQuantizeTransformation::FakeQuantizeTransformation(const Params& params) : L
}
bool FakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) {
std::shared_ptr<opset1::FakeQuantize> layer = std::dynamic_pointer_cast<opset1::FakeQuantize>(m.get_match_root());
if (!QuantizationDetails::outputLayoutIsSupported(layer)) {
const auto layer = ov::as_type_ptr<opset1::FakeQuantize>(m.get_match_root());
if (!layer || !QuantizationDetails::outputLayoutIsSupported(layer)) {
return false;
}

View File

@ -273,7 +273,7 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
bool FakeQuantizeDecompositionTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
auto layer = ov::as_type_ptr<opset1::FakeQuantize>(m.get_match_root());
if (!NetworkHelper::isQuantizeSupported(layer)) {
if (!layer || !NetworkHelper::isQuantizeSupported(layer)) {
return false;
}

View File

@ -31,7 +31,10 @@ FuseFakeQuantizeTransformation::FuseFakeQuantizeTransformation(const Params& par
}
bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) {
std::shared_ptr<opset1::FakeQuantize> fakeQuantize = ov::as_type_ptr<ngraph::opset1::FakeQuantize>(m.get_match_root());
auto fakeQuantize = ov::as_type_ptr<ngraph::opset1::FakeQuantize>(m.get_match_root());
if (!fakeQuantize)
return false;
do {
fakeQuantize = handle(context, fakeQuantize);
} while (fakeQuantize != nullptr);

View File

@ -104,12 +104,12 @@ void make_matcher_type_relaxed(ngraph::pass::GraphRewrite* transformation) {
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher& m) {
auto l_node = std::dynamic_pointer_cast<BaseOp>(m.get_match_root());
if (!l_node) {
THROW_TRANSFORMATION_EXCEPTION << "unexpected operation type for type relaxed conversion";
}
if (std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(l_node)) {
return false;
}
if (!l_node) {
THROW_IE_LPT_EXCEPTION(*l_node) << "unexpected operation type";
}
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::LPT_LT, "LowPrecisionTypeRelaxedMatcher");

View File

@ -116,7 +116,7 @@ std::shared_ptr<opset1::Constant> NetworkHelper::foldDequantizationConstant(
const auto result = ov::as_type_ptr<opset1::Constant>(outputs[outIdx].get_node_shared_ptr());
if (result == nullptr) {
THROW_IE_LPT_EXCEPTION(*result) << "result of constant folding is not constant";
THROW_TRANSFORMATION_EXCEPTION << "result of constant folding is not constant";
}
return result;
@ -441,9 +441,9 @@ std::vector<size_t> NetworkHelper::updateReshapeValues(
}
std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter(std::shared_ptr<Node> node) {
std::shared_ptr<ngraph::opset1::Multiply> multiply = ov::as_type_ptr<opset1::Multiply>(std::move(node));
const auto multiply = ov::as_type_ptr<opset1::Multiply>(std::move(node));
if (!multiply) {
THROW_IE_LPT_EXCEPTION(*multiply) << "Unexpected operation type";
THROW_TRANSFORMATION_EXCEPTION << "Unexpected operation type in the optimizeMultipliesAfter method";
}
if (multiply->output(0).get_target_inputs().size() == 1) {
@ -1708,6 +1708,10 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor
dequantization.convert->get_output_element_type(0) :
deqPrecision;
parent = std::make_shared<opset1::Convert>(parent, convertOutputPrecision);
if (dequantization.convert == nullptr) {
THROW_TRANSFORMATION_EXCEPTION << "dequantization convert is absent";
}
parent->set_friendly_name(dequantization.convert->get_friendly_name() + "_" + std::to_string(i + 1));
ngraph::copy_runtime_info(dequantization.convert, parent);
}
@ -1761,6 +1765,9 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor
}
auto newOperation = operation->clone_with_new_inputs(ngraph::OutputVector(newNodes.begin(), newNodes.end()));
NetworkHelper::copyInfo(operation, newOperation);
if (dequantization.multiply == nullptr) {
THROW_TRANSFORMATION_EXCEPTION << "dequantization operations must end with multiply";
}
replace_node(dequantization.multiply, newOperation);
if (const auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation)) {

View File

@ -112,7 +112,7 @@ ngraph::pass::low_precision::PullReshapeThroughDequantization::PullReshapeThroug
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher & m) -> bool {
const auto& opsMap = m.get_pattern_value_map();
auto reshape = opsMap.find(reshapeWrapper)->second.get_node()->shared_from_this();
auto reshape = opsMap.at(reshapeWrapper).get_node_shared_ptr();
auto child = reshape->get_output_target_inputs(0).begin()->get_node();
if (ov::is_type<opset1::GroupConvolution>(child)) {