Fix broken op create test in python api
This commit is contained in:
@@ -117,7 +117,7 @@ def test_ctc_greedy_decoder_seq_len(fp_dtype, int_dtype, int_ci, int_sl, merge_r
|
|||||||
assert list(node.get_output_shape(0)) == expected_shape
|
assert list(node.get_output_shape(0)) == expected_shape
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", np_types)
|
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
|
||||||
def test_deformable_convolution(dtype):
|
def test_deformable_convolution(dtype):
|
||||||
strides = np.array([1, 1])
|
strides = np.array([1, 1])
|
||||||
pads_begin = np.array([0, 0])
|
pads_begin = np.array([0, 0])
|
||||||
@@ -125,7 +125,7 @@ def test_deformable_convolution(dtype):
|
|||||||
dilations = np.array([1, 1])
|
dilations = np.array([1, 1])
|
||||||
|
|
||||||
input0_shape = [1, 1, 9, 9]
|
input0_shape = [1, 1, 9, 9]
|
||||||
input1_shape = [1, 1, 9, 9]
|
input1_shape = [1, 18, 9, 9]
|
||||||
input2_shape = [1, 1, 3, 3]
|
input2_shape = [1, 1, 3, 3]
|
||||||
expected_shape = [1, 1, 7, 7]
|
expected_shape = [1, 1, 7, 7]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user