[Transformations] FCTransformations: leftovers (#8062)
This commit is contained in:
parent
cb26462288
commit
243efa465e
@ -31,7 +31,7 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||
|
||||
auto shape_a = fc_input_a.get_partial_shape();
|
||||
auto shape_b = fc_input_b.get_partial_shape();
|
||||
NGRAPH_CHECK(shape_b.is_static()); // requested 2nd input with static shape in the matcher
|
||||
NGRAPH_CHECK(shape_b.is_static());
|
||||
|
||||
auto rank_a = shape_a.rank().get_length();
|
||||
auto rank_b = shape_b.rank().get_length();
|
||||
@ -141,7 +141,7 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||
|
||||
if (rank_b != 2) {
|
||||
ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1);
|
||||
NGRAPH_CHECK(K.is_static()); // requested 2nd input with static shape in the matcher
|
||||
NGRAPH_CHECK(K.is_static());
|
||||
std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
|
||||
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
|
||||
fc_input_b = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
|
||||
|
@ -57,12 +57,16 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
||||
"Weights pshape must be static");
|
||||
const auto weights_shape = weights_pshape.to_shape();
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
weights_pshape.size() > 0,
|
||||
"Weights rank must be greater than 0");
|
||||
|
||||
const auto o_channels = weights_pshape[0];
|
||||
if (input_size == 3) {
|
||||
const auto bias_shape = get_input_partial_shape(2);
|
||||
const auto expected_bias_shape = ngraph::PartialShape{ o_channels };
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
bias_shape == expected_bias_shape,
|
||||
bias_shape.is_static() && bias_shape.compatible(expected_bias_shape),
|
||||
"Bias shape is incorrect. Current value is: ",
|
||||
bias_shape,
|
||||
", expected: ",
|
||||
@ -83,11 +87,13 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
||||
}
|
||||
output_pshape.push_back(o_channels);
|
||||
|
||||
if (m_output_rank.is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
m_output_rank.is_static(),
|
||||
"Output rank must be static if activations rank is static.");
|
||||
|
||||
while (output_pshape.rank().get_length() < m_output_rank.get_length()) {
|
||||
output_pshape.insert(output_pshape.begin(), 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output_pshape = ngraph::PartialShape::dynamic();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user