[Transformations] FCTransformations: leftovers (#8062)

This commit is contained in:
Vladislav Golubev 2021-10-21 10:38:14 +03:00 committed by GitHub
parent cb26462288
commit 243efa465e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 7 deletions

View File

@ -31,7 +31,7 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
auto shape_a = fc_input_a.get_partial_shape(); auto shape_a = fc_input_a.get_partial_shape();
auto shape_b = fc_input_b.get_partial_shape(); auto shape_b = fc_input_b.get_partial_shape();
NGRAPH_CHECK(shape_b.is_static()); // requested 2nd input with static shape in the matcher NGRAPH_CHECK(shape_b.is_static());
auto rank_a = shape_a.rank().get_length(); auto rank_a = shape_a.rank().get_length();
auto rank_b = shape_b.rank().get_length(); auto rank_b = shape_b.rank().get_length();
@ -141,7 +141,7 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
if (rank_b != 2) { if (rank_b != 2) {
ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1); ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1);
NGRAPH_CHECK(K.is_static()); // requested 2nd input with static shape in the matcher NGRAPH_CHECK(K.is_static());
std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) }; std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values); auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
fc_input_b = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false); fc_input_b = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);

View File

@ -57,12 +57,16 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
"Weights pshape must be static"); "Weights pshape must be static");
const auto weights_shape = weights_pshape.to_shape(); const auto weights_shape = weights_pshape.to_shape();
NODE_VALIDATION_CHECK(this,
weights_pshape.size() > 0,
"Weights rank must be greater than 0");
const auto o_channels = weights_pshape[0]; const auto o_channels = weights_pshape[0];
if (input_size == 3) { if (input_size == 3) {
const auto bias_shape = get_input_partial_shape(2); const auto bias_shape = get_input_partial_shape(2);
const auto expected_bias_shape = ngraph::PartialShape{ o_channels }; const auto expected_bias_shape = ngraph::PartialShape{ o_channels };
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
bias_shape == expected_bias_shape, bias_shape.is_static() && bias_shape.compatible(expected_bias_shape),
"Bias shape is incorrect. Current value is: ", "Bias shape is incorrect. Current value is: ",
bias_shape, bias_shape,
", expected: ", ", expected: ",
@ -83,10 +87,12 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
} }
output_pshape.push_back(o_channels); output_pshape.push_back(o_channels);
if (m_output_rank.is_static()) { NODE_VALIDATION_CHECK(this,
while (output_pshape.rank().get_length() < m_output_rank.get_length()) { m_output_rank.is_static(),
output_pshape.insert(output_pshape.begin(), 1); "Output rank must be static if activations rank is static.");
}
while (output_pshape.rank().get_length() < m_output_rank.get_length()) {
output_pshape.insert(output_pshape.begin(), 1);
} }
} else { } else {
output_pshape = ngraph::PartialShape::dynamic(); output_pshape = ngraph::PartialShape::dynamic();