Fix Assign, Matmul ops for correct work with keep_shape_ops feature for Kaldi models (#1885)

* fix assign and matmul ops

* fix unit test

* static shapes by default for kaldi
This commit is contained in:
Ivan Tikhonov 2020-08-24 14:11:01 +03:00 committed by GitHub
parent 2acf3f27e1
commit c74643f6b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 10 deletions

View File

@ -225,7 +225,7 @@ def prepare_ir(argv: argparse.Namespace):
from mo.front.mxnet.register_custom_ops import get_front_classes
import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
elif is_kaldi:
argv.static_shape = True # ticket #36794
argv.static_shape = True
from mo.front.kaldi.register_custom_ops import get_front_classes
import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
elif is_onnx:

View File

@ -61,11 +61,19 @@ void op::v3::Assign::validate_and_infer_types()
"Variables identifiers are inconsistent.");
NODE_VALIDATION_CHECK(
this, arg_t == variable_info.data_type, "Variables types are inconsistent.");
NODE_VALIDATION_CHECK(this,
output_shape == variable_info.data_shape,
"Variables output shapes are inconsistent.");
set_output_type(0, arg_t, output_shape);
if (output_shape.is_static() && variable_info.data_shape.is_static())
{
NODE_VALIDATION_CHECK(this,
output_shape == variable_info.data_shape,
"Variables output shapes are inconsistent.");
set_output_type(0, arg_t, output_shape);
}
else
{
set_output_type(0, arg_t, PartialShape::dynamic());
}
}
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const

View File

@ -69,6 +69,10 @@ void op::MatMul::pre_validate_and_infer_types()
Rank max_rank = A_rank.get_length() > B_rank.get_length() ? A_rank : B_rank;
set_output_type(0, result_et, PartialShape::dynamic(max_rank));
}
else
{
set_output_type(0, result_et, PartialShape::dynamic());
}
}
OutputVector op::MatMul::decompose_op() const

View File

@ -68,18 +68,18 @@ TEST(op_eval, matmul_dynamic_0_elem_arg)
std::vector<std::vector<Shape>> shapes{{Shape{2, 0}, Shape{0, 2}, Shape{2, 2}},
{Shape{0, 2}, Shape{2, 0}, Shape{0, 0}}};
std::vector<std::vector<int32_t>> arg_inputs{{}, {}};
std::vector<std::vector<int32_t>> expected_result{{0, 0, 0, 0}, {}};
std::vector<std::vector<float>> arg_inputs{{}, {}};
std::vector<std::vector<float>> expected_result{{0, 0, 0, 0}, {}};
for (size_t i = 0; i < arg_inputs.size(); i++)
{
auto result = make_shared<HostTensor>();
ASSERT_TRUE(
fun->evaluate({result},
{make_host_tensor<element::Type_t::i32>(shapes[i][0], arg_inputs[i]),
make_host_tensor<element::Type_t::i32>(shapes[i][1], arg_inputs[i])}));
{make_host_tensor<element::Type_t::f32>(shapes[i][0], arg_inputs[i]),
make_host_tensor<element::Type_t::f32>(shapes[i][1], arg_inputs[i])}));
EXPECT_EQ(result->get_shape(), (shapes[i][2]));
ASSERT_EQ(read_vector<int32_t>(result), expected_result[i]);
ASSERT_EQ(read_vector<float>(result), expected_result[i]);
}
}