Fix compilation error in pyopenvino (#9150)
This commit is contained in:
parent
20a5afa84e
commit
0c68574aa7
@ -70,10 +70,10 @@ PYBIND11_MODULE(pyopenvino, m) {
|
||||
m.def("set_batch", &ov::set_batch);
|
||||
m.def(
|
||||
"set_batch",
|
||||
[](const std::shared_ptr<ov::Function>& f, int64_t value) {
|
||||
return ov::set_batch(f, ov::Dimension(value));
|
||||
[](const std::shared_ptr<ov::Model>& model, int64_t value) {
|
||||
return ov::set_batch(model, ov::Dimension(value));
|
||||
},
|
||||
py::arg("function"),
|
||||
py::arg("model"),
|
||||
py::arg("batch_size") = -1);
|
||||
|
||||
regclass_graph_PyRTMap(m);
|
||||
|
@ -277,7 +277,7 @@ def test_get_batch():
|
||||
param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
|
||||
param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
|
||||
add = ops.add(param1, param2)
|
||||
func = Function(add, [param1, param2], "TestFunction")
|
||||
func = Model(add, [param1, param2], "TestFunction")
|
||||
param = func.get_parameters()[0]
|
||||
param.set_layout(Layout("NC"))
|
||||
assert get_batch(func) == 2
|
||||
@ -289,7 +289,7 @@ def test_get_batch_CHWN():
|
||||
param3 = ops.parameter(Shape([3, 1, 3, 4]), dtype=np.float32, name="data3")
|
||||
add = ops.add(param1, param2)
|
||||
add2 = ops.add(add, param3)
|
||||
func = Function(add2, [param1, param2, param3], "TestFunction")
|
||||
func = Model(add2, [param1, param2, param3], "TestFunction")
|
||||
param = func.get_parameters()[0]
|
||||
param.set_layout(Layout("CHWN"))
|
||||
assert get_batch(func) == 4
|
||||
@ -299,7 +299,7 @@ def test_set_batch_dimension():
|
||||
param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
|
||||
param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
|
||||
add = ops.add(param1, param2)
|
||||
func = Function(add, [param1, param2], "TestFunction")
|
||||
func = Model(add, [param1, param2], "TestFunction")
|
||||
func_param1 = func.get_parameters()[0]
|
||||
func_param2 = func.get_parameters()[1]
|
||||
# batch == 2
|
||||
@ -318,7 +318,7 @@ def test_set_batch_int():
|
||||
param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
|
||||
param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
|
||||
add = ops.add(param1, param2)
|
||||
func = Function(add, [param1, param2], "TestFunction")
|
||||
func = Model(add, [param1, param2], "TestFunction")
|
||||
func_param1 = func.get_parameters()[0]
|
||||
func_param2 = func.get_parameters()[1]
|
||||
# batch == 2
|
||||
@ -337,7 +337,7 @@ def test_set_batch_default_batch_size():
|
||||
param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
|
||||
param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
|
||||
add = ops.add(param1, param2)
|
||||
func = Function(add, [param1, param2], "TestFunction")
|
||||
func = Model(add, [param1, param2], "TestFunction")
|
||||
func_param1 = func.get_parameters()[0]
|
||||
func_param1.set_layout(Layout("NC"))
|
||||
set_batch(func)
|
||||
|
Loading…
Reference in New Issue
Block a user