Fix Assign, Matmul ops for correct work with keep_shape_ops feature for Kaldi models (#1885)

* fix assign and matmul ops

* fix unit test

* static shapes by default for kaldi
This commit is contained in:
Ivan Tikhonov
2020-08-24 14:11:01 +03:00
committed by GitHub
parent 2acf3f27e1
commit c74643f6b6
4 changed files with 22 additions and 10 deletions

View File

@@ -61,11 +61,19 @@ void op::v3::Assign::validate_and_infer_types()
"Variables identifiers are inconsistent.");
NODE_VALIDATION_CHECK(
this, arg_t == variable_info.data_type, "Variables types are inconsistent.");
NODE_VALIDATION_CHECK(this,
output_shape == variable_info.data_shape,
"Variables output shapes are inconsistent.");
set_output_type(0, arg_t, output_shape);
if (output_shape.is_static() && variable_info.data_shape.is_static())
{
NODE_VALIDATION_CHECK(this,
output_shape == variable_info.data_shape,
"Variables output shapes are inconsistent.");
set_output_type(0, arg_t, output_shape);
}
else
{
set_output_type(0, arg_t, PartialShape::dynamic());
}
}
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const

View File

@@ -69,6 +69,10 @@ void op::MatMul::pre_validate_and_infer_types()
Rank max_rank = A_rank.get_length() > B_rank.get_length() ? A_rank : B_rank;
set_output_type(0, result_et, PartialShape::dynamic(max_rank));
}
else
{
set_output_type(0, result_et, PartialShape::dynamic());
}
}
OutputVector op::MatMul::decompose_op() const