[Transformations] FCTransformations: leftovers (#8062)
This commit is contained in:
parent
cb26462288
commit
243efa465e
@ -31,7 +31,7 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
|||||||
|
|
||||||
auto shape_a = fc_input_a.get_partial_shape();
|
auto shape_a = fc_input_a.get_partial_shape();
|
||||||
auto shape_b = fc_input_b.get_partial_shape();
|
auto shape_b = fc_input_b.get_partial_shape();
|
||||||
NGRAPH_CHECK(shape_b.is_static()); // requested 2nd input with static shape in the matcher
|
NGRAPH_CHECK(shape_b.is_static());
|
||||||
|
|
||||||
auto rank_a = shape_a.rank().get_length();
|
auto rank_a = shape_a.rank().get_length();
|
||||||
auto rank_b = shape_b.rank().get_length();
|
auto rank_b = shape_b.rank().get_length();
|
||||||
@ -141,7 +141,7 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
|||||||
|
|
||||||
if (rank_b != 2) {
|
if (rank_b != 2) {
|
||||||
ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1);
|
ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1);
|
||||||
NGRAPH_CHECK(K.is_static()); // requested 2nd input with static shape in the matcher
|
NGRAPH_CHECK(K.is_static());
|
||||||
std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
|
std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
|
||||||
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
|
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
|
||||||
fc_input_b = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
|
fc_input_b = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
|
||||||
|
@ -57,12 +57,16 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
|||||||
"Weights pshape must be static");
|
"Weights pshape must be static");
|
||||||
const auto weights_shape = weights_pshape.to_shape();
|
const auto weights_shape = weights_pshape.to_shape();
|
||||||
|
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
weights_pshape.size() > 0,
|
||||||
|
"Weights rank must be greater than 0");
|
||||||
|
|
||||||
const auto o_channels = weights_pshape[0];
|
const auto o_channels = weights_pshape[0];
|
||||||
if (input_size == 3) {
|
if (input_size == 3) {
|
||||||
const auto bias_shape = get_input_partial_shape(2);
|
const auto bias_shape = get_input_partial_shape(2);
|
||||||
const auto expected_bias_shape = ngraph::PartialShape{ o_channels };
|
const auto expected_bias_shape = ngraph::PartialShape{ o_channels };
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
bias_shape == expected_bias_shape,
|
bias_shape.is_static() && bias_shape.compatible(expected_bias_shape),
|
||||||
"Bias shape is incorrect. Current value is: ",
|
"Bias shape is incorrect. Current value is: ",
|
||||||
bias_shape,
|
bias_shape,
|
||||||
", expected: ",
|
", expected: ",
|
||||||
@ -83,10 +87,12 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
|||||||
}
|
}
|
||||||
output_pshape.push_back(o_channels);
|
output_pshape.push_back(o_channels);
|
||||||
|
|
||||||
if (m_output_rank.is_static()) {
|
NODE_VALIDATION_CHECK(this,
|
||||||
while (output_pshape.rank().get_length() < m_output_rank.get_length()) {
|
m_output_rank.is_static(),
|
||||||
output_pshape.insert(output_pshape.begin(), 1);
|
"Output rank must be static if activations rank is static.");
|
||||||
}
|
|
||||||
|
while (output_pshape.rank().get_length() < m_output_rank.get_length()) {
|
||||||
|
output_pshape.insert(output_pshape.begin(), 1);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
output_pshape = ngraph::PartialShape::dynamic();
|
output_pshape = ngraph::PartialShape::dynamic();
|
||||||
|
Loading…
Reference in New Issue
Block a user