Fixed static analysis issues (transformations) (#3281)
This commit is contained in:
parent
7e2305543a
commit
a555efe287
@ -112,7 +112,10 @@ ngraph::pass::ConvertLSTMSequenceMatcher::ConvertLSTMSequenceMatcher() {
|
||||
if (seq_axis == 1) {
|
||||
ngraph::replace_node(lstm_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)});
|
||||
} else {
|
||||
auto transpose_after = lstm_sequence->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
|
||||
const auto &lstm_target_inputs = lstm_sequence->output(0).get_target_inputs();
|
||||
if (lstm_target_inputs.empty())
|
||||
return false;
|
||||
auto transpose_after = lstm_target_inputs.begin()->get_node()->shared_from_this();
|
||||
ngraph::replace_node(transpose_after, unsqueeze_1);
|
||||
ngraph::replace_node(lstm_sequence, {lstm_sequence_ie->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)});
|
||||
}
|
||||
@ -179,7 +182,10 @@ ngraph::pass::ConvertGRUSequenceMatcher::ConvertGRUSequenceMatcher() {
|
||||
if (seq_axis == 1) {
|
||||
ngraph::replace_node(gru_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0)});
|
||||
} else {
|
||||
auto transpose_after = gru_sequence->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
|
||||
const auto &gru_target_inputs = gru_sequence->output(0).get_target_inputs();
|
||||
if (gru_target_inputs.empty())
|
||||
return false;
|
||||
auto transpose_after = gru_target_inputs.begin()->get_node()->shared_from_this();
|
||||
ngraph::replace_node(transpose_after, unsqueeze_1);
|
||||
ngraph::replace_node(gru_sequence, {gru_sequence_ie->output(0), unsqueeze_2->output(0)});
|
||||
}
|
||||
@ -247,7 +253,10 @@ ngraph::pass::ConvertRNNSequenceMatcher::ConvertRNNSequenceMatcher() {
|
||||
if (seq_axis == 1) {
|
||||
ngraph::replace_node(rnn_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0)});
|
||||
} else {
|
||||
auto transpose_after = rnn_sequence->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
|
||||
const auto &rnn_target_inputs = rnn_sequence->output(0).get_target_inputs();
|
||||
if (rnn_target_inputs.empty())
|
||||
return false;
|
||||
auto transpose_after = rnn_target_inputs.begin()->get_node()->shared_from_this();
|
||||
ngraph::replace_node(transpose_after, unsqueeze_1);
|
||||
ngraph::replace_node(rnn_sequence, {rnn_sequence_ie->output(0), unsqueeze_2->output(0)});
|
||||
}
|
||||
|
@ -214,7 +214,7 @@ bool ConcatTransformation::isPrecisionPreserved(std::shared_ptr<Node>) const noe
|
||||
|
||||
bool ConcatTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
|
||||
std::shared_ptr<opset1::Concat> concat = as_type_ptr<opset1::Concat>(layer);
|
||||
return concat->get_axis() == 1ul;
|
||||
return concat && concat->get_axis() == 1ul;
|
||||
}
|
||||
|
||||
|
||||
|
@ -194,6 +194,8 @@ static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node) {
|
||||
}
|
||||
|
||||
auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(node);
|
||||
if (unsqueeze == nullptr)
|
||||
return false;
|
||||
auto input = unsqueeze->input_value(0).get_node_shared_ptr();
|
||||
auto squeeze = as_type_ptr<opset3::Squeeze>(input);
|
||||
auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
|
||||
@ -260,6 +262,8 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node) {
|
||||
}
|
||||
|
||||
auto squeeze = as_type_ptr<opset3::Squeeze>(node);
|
||||
if (squeeze == nullptr)
|
||||
return false;
|
||||
auto input = squeeze->input_value(0).get_node_shared_ptr();
|
||||
auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
|
||||
auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
|
||||
|
@ -349,6 +349,8 @@ static std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::C
|
||||
|
||||
auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
|
||||
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
|
||||
if (dst_data == nullptr)
|
||||
throw ngraph_error("Can't get destination data pointer");
|
||||
|
||||
std::vector<dst_type> final_data;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
|
@ -132,6 +132,8 @@ ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSe
|
||||
|
||||
auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
|
||||
const auto& lstm_cell = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(found_cell);
|
||||
if (lstm_cell == nullptr)
|
||||
return false;
|
||||
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
|
||||
if (slice_axis == 0) {
|
||||
auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
|
||||
@ -283,6 +285,8 @@ ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequ
|
||||
}
|
||||
|
||||
const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset5::RNNCell>(pattern_map[cell]);
|
||||
if (rnn_cell == nullptr)
|
||||
return false;
|
||||
|
||||
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
|
||||
if (slice_axis == 0) {
|
||||
@ -434,6 +438,8 @@ ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequ
|
||||
}
|
||||
|
||||
const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset5::GRUCell>(pattern_map[cell]);
|
||||
if (rnn_cell == nullptr)
|
||||
return false;
|
||||
|
||||
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
|
||||
if (slice_axis == 0) {
|
||||
|
@ -20,7 +20,7 @@ ngraph::pass::LogSoftmaxDecomposition::LogSoftmaxDecomposition() {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto log_softmax_node = std::dynamic_pointer_cast<ngraph::opset5::LogSoftmax>(pattern_to_output.at(log_softmax).get_node_shared_ptr());
|
||||
|
||||
if (m_transformation_callback(log_softmax_node)) {
|
||||
if (log_softmax_node == nullptr || m_transformation_callback(log_softmax_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user