Fix compilation error in pyopenvino (#9150)

This commit is contained in:
Mateusz Tabaka 2021-12-10 12:01:44 +01:00 committed by GitHub
parent 20a5afa84e
commit 0c68574aa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 8 deletions

View File

@ -70,10 +70,10 @@ PYBIND11_MODULE(pyopenvino, m) {
m.def("set_batch", &ov::set_batch);
m.def(
"set_batch",
[](const std::shared_ptr<ov::Function>& f, int64_t value) {
return ov::set_batch(f, ov::Dimension(value));
[](const std::shared_ptr<ov::Model>& model, int64_t value) {
return ov::set_batch(model, ov::Dimension(value));
},
py::arg("function"),
py::arg("model"),
py::arg("batch_size") = -1);
regclass_graph_PyRTMap(m);

View File

@ -277,7 +277,7 @@ def test_get_batch():
param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
add = ops.add(param1, param2)
func = Function(add, [param1, param2], "TestFunction")
func = Model(add, [param1, param2], "TestFunction")
param = func.get_parameters()[0]
param.set_layout(Layout("NC"))
assert get_batch(func) == 2
@ -289,7 +289,7 @@ def test_get_batch_CHWN():
param3 = ops.parameter(Shape([3, 1, 3, 4]), dtype=np.float32, name="data3")
add = ops.add(param1, param2)
add2 = ops.add(add, param3)
func = Function(add2, [param1, param2, param3], "TestFunction")
func = Model(add2, [param1, param2, param3], "TestFunction")
param = func.get_parameters()[0]
param.set_layout(Layout("CHWN"))
assert get_batch(func) == 4
@ -299,7 +299,7 @@ def test_set_batch_dimension():
param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
add = ops.add(param1, param2)
func = Function(add, [param1, param2], "TestFunction")
func = Model(add, [param1, param2], "TestFunction")
func_param1 = func.get_parameters()[0]
func_param2 = func.get_parameters()[1]
# batch == 2
@ -318,7 +318,7 @@ def test_set_batch_int():
param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
add = ops.add(param1, param2)
func = Function(add, [param1, param2], "TestFunction")
func = Model(add, [param1, param2], "TestFunction")
func_param1 = func.get_parameters()[0]
func_param2 = func.get_parameters()[1]
# batch == 2
@ -337,7 +337,7 @@ def test_set_batch_default_batch_size():
param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
add = ops.add(param1, param2)
func = Function(add, [param1, param2], "TestFunction")
func = Model(add, [param1, param2], "TestFunction")
func_param1 = func.get_parameters()[0]
func_param1.set_layout(Layout("NC"))
set_batch(func)