Fix Assign, Matmul ops for correct work with keep_shape_ops feature for Kaldi models (#1885)
* fix assign and matmul ops * fix unit test * static shapes by default for kaldi
This commit is contained in:
@@ -61,11 +61,19 @@ void op::v3::Assign::validate_and_infer_types()
|
||||
"Variables identifiers are inconsistent.");
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, arg_t == variable_info.data_type, "Variables types are inconsistent.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
output_shape == variable_info.data_shape,
|
||||
"Variables output shapes are inconsistent.");
|
||||
|
||||
set_output_type(0, arg_t, output_shape);
|
||||
if (output_shape.is_static() && variable_info.data_shape.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
output_shape == variable_info.data_shape,
|
||||
"Variables output shapes are inconsistent.");
|
||||
|
||||
set_output_type(0, arg_t, output_shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
set_output_type(0, arg_t, PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
|
||||
@@ -69,6 +69,10 @@ void op::MatMul::pre_validate_and_infer_types()
|
||||
Rank max_rank = A_rank.get_length() > B_rank.get_length() ? A_rank : B_rank;
|
||||
set_output_type(0, result_et, PartialShape::dynamic(max_rank));
|
||||
}
|
||||
else
|
||||
{
|
||||
set_output_type(0, result_et, PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
|
||||
OutputVector op::MatMul::decompose_op() const
|
||||
|
||||
Reference in New Issue
Block a user