Fix Assign and MatMul ops to work correctly with the keep_shape_ops feature for Kaldi models (#1885)
* fix Assign and MatMul ops
* fix unit test
* static shapes by default for Kaldi
parent 2acf3f27e1
commit c74643f6b6
@@ -225,7 +225,7 @@ def prepare_ir(argv: argparse.Namespace):
         from mo.front.mxnet.register_custom_ops import get_front_classes
         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
     elif is_kaldi:
-        argv.static_shape = True # ticket #36794
+        argv.static_shape = True
         from mo.front.kaldi.register_custom_ops import get_front_classes
         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
     elif is_onnx:
@@ -61,11 +61,19 @@ void op::v3::Assign::validate_and_infer_types()
                           "Variables identifiers are inconsistent.");
     NODE_VALIDATION_CHECK(
         this, arg_t == variable_info.data_type, "Variables types are inconsistent.");
-    NODE_VALIDATION_CHECK(this,
-                          output_shape == variable_info.data_shape,
-                          "Variables output shapes are inconsistent.");
-
-    set_output_type(0, arg_t, output_shape);
+    if (output_shape.is_static() && variable_info.data_shape.is_static())
+    {
+        NODE_VALIDATION_CHECK(this,
+                              output_shape == variable_info.data_shape,
+                              "Variables output shapes are inconsistent.");
+
+        set_output_type(0, arg_t, output_shape);
+    }
+    else
+    {
+        set_output_type(0, arg_t, PartialShape::dynamic());
+    }
 }
 
 shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const
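With --keep_shape_ops the input to Assign can still have a dynamic shape at validation time, so the fixed op compares shapes only when both the input shape and the variable's recorded shape are static, and otherwise reports a fully dynamic output. A minimal sketch of that rule against the public ngraph::PartialShape API (infer_assign_shape is a hypothetical helper, not part of ngraph; the exception stands in for NODE_VALIDATION_CHECK):

#include <ngraph/partial_shape.hpp>
#include <stdexcept>

using ngraph::PartialShape;

// Hypothetical helper showing the shape rule applied by the fixed
// Assign::validate_and_infer_types: compare shapes only when both are static,
// otherwise fall back to a fully dynamic output shape.
PartialShape infer_assign_shape(const PartialShape& input_shape,
                                const PartialShape& variable_shape)
{
    if (input_shape.is_static() && variable_shape.is_static())
    {
        if (!(input_shape == variable_shape))
        {
            // Stands in for NODE_VALIDATION_CHECK in the real op.
            throw std::runtime_error("Variables output shapes are inconsistent.");
        }
        return input_shape;
    }
    // Dynamic shapes (e.g. IRs converted with --keep_shape_ops) defer the check.
    return PartialShape::dynamic();
}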
@@ -69,6 +69,10 @@ void op::MatMul::pre_validate_and_infer_types()
         Rank max_rank = A_rank.get_length() > B_rank.get_length() ? A_rank : B_rank;
         set_output_type(0, result_et, PartialShape::dynamic(max_rank));
     }
+    else
+    {
+        set_output_type(0, result_et, PartialShape::dynamic());
+    }
 }
 
 OutputVector op::MatMul::decompose_op() const
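The MatMul change adds only the else branch: when either input rank is unknown, pre_validate_and_infer_types now sets a fully dynamic output shape instead of leaving the output type undefined. A small sketch of the resulting rank rule, assuming only the public PartialShape/Rank API (matmul_pre_infer_shape is a hypothetical name, not part of ngraph):

#include <ngraph/partial_shape.hpp>
#include <ngraph/rank.hpp>

using ngraph::PartialShape;
using ngraph::Rank;

// Hypothetical helper mirroring the rank propagation of the fixed
// MatMul::pre_validate_and_infer_types when dimensions cannot be inferred yet.
PartialShape matmul_pre_infer_shape(const Rank& a_rank, const Rank& b_rank)
{
    if (a_rank.is_static() && b_rank.is_static())
    {
        // Both ranks known: keep the larger rank, every dimension stays dynamic.
        Rank max_rank = a_rank.get_length() > b_rank.get_length() ? a_rank : b_rank;
        return PartialShape::dynamic(max_rank);
    }
    // The new else branch: unknown rank means a fully dynamic output shape.
    return PartialShape::dynamic();
}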
@@ -68,18 +68,18 @@ TEST(op_eval, matmul_dynamic_0_elem_arg)
     std::vector<std::vector<Shape>> shapes{{Shape{2, 0}, Shape{0, 2}, Shape{2, 2}},
                                            {Shape{0, 2}, Shape{2, 0}, Shape{0, 0}}};
 
-    std::vector<std::vector<int32_t>> arg_inputs{{}, {}};
-    std::vector<std::vector<int32_t>> expected_result{{0, 0, 0, 0}, {}};
+    std::vector<std::vector<float>> arg_inputs{{}, {}};
+    std::vector<std::vector<float>> expected_result{{0, 0, 0, 0}, {}};
 
     for (size_t i = 0; i < arg_inputs.size(); i++)
     {
         auto result = make_shared<HostTensor>();
         ASSERT_TRUE(
             fun->evaluate({result},
-                          {make_host_tensor<element::Type_t::i32>(shapes[i][0], arg_inputs[i]),
-                           make_host_tensor<element::Type_t::i32>(shapes[i][1], arg_inputs[i])}));
+                          {make_host_tensor<element::Type_t::f32>(shapes[i][0], arg_inputs[i]),
+                           make_host_tensor<element::Type_t::f32>(shapes[i][1], arg_inputs[i])}));
         EXPECT_EQ(result->get_shape(), (shapes[i][2]));
-        ASSERT_EQ(read_vector<int32_t>(result), expected_result[i]);
+        ASSERT_EQ(read_vector<float>(result), expected_result[i]);
     }
 }
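The test now feeds f32 host tensors and reads f32 results, matching the element type the evaluated MatMul is built with. The function under test (`fun`) is created earlier in the file and is not part of this hunk; a sketch of how such a dynamically-shaped MatMul function is typically assembled, with the builder name being an assumption rather than the test's actual code:

#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

// Hypothetical builder, similar in spirit to the `fun` evaluated in the test: a
// MatMul over two fully dynamic f32 parameters, so input shapes are resolved
// only when evaluate() is called with concrete host tensors.
std::shared_ptr<Function> make_dynamic_matmul_function()
{
    auto A = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto B = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto matmul = std::make_shared<op::MatMul>(A, B, false, false);
    return std::make_shared<Function>(OutputVector{matmul}, ParameterVector{A, B});
}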