[OV20] ov::Function - get/set batch size (#8955)

* Initial version (no tests)

* Added tests

* Fix CentOS build

* Applied review comments

* Renamed 'ov::util::get_batch_size' to 'ov::pass::get_batch'. The same rename was applied for 'set_batch_size'

* Changed to ov::get_batch and ov::set_batch

Author: Mikhail Nosov
Date: 2021-12-07 16:40:48 +03:00 (committed by GitHub)
parent 92bdddbfc9
commit 4abaef6702
4 changed files with 359 additions and 3 deletions


@@ -320,9 +320,8 @@ int main(int argc, char* argv[]) {
// -------- Step 4. Reshape a model --------
// Setting batch size using image count
const size_t batch_size = imagesData.size();
input_shape[layout::batch_idx(tensor_layout)] = batch_size;
model->reshape({{input.get_any_name(), input_shape}});
const auto batch_size = static_cast<int64_t>(imagesData.size());
ov::set_batch(model, batch_size);
slog::info << "Batch size is " << std::to_string(batch_size) << slog::endl;
const auto outputShape = model->output().get_shape();
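
For context, a minimal sketch of the sample's new flow (assuming `model` is the std::shared_ptr<ov::Function> loaded earlier in the sample and `imagesData` holds the loaded images; the layout call is illustrative, not the sample's exact preprocessing code):

// Sketch only: ov::set_batch requires at least one parameter whose layout
// contains an 'N' dimension, e.g. declared beforehand with:
//     model->get_parameters()[0]->set_layout("NHWC");
const auto batch_size = static_cast<int64_t>(imagesData.size());
ov::set_batch(model, batch_size);  // reshapes every input whose layout has 'N'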


@@ -352,4 +352,41 @@ public:
OPENVINO_RTTI("AttributeAdapter<std::shared_ptr<Function>");
BWDCMP_RTTI_DECLARATION;
};
/// \brief Helper method to get the batch size associated with a Function
/// \details Checks the layout of each parameter in a Function and extracts the value of the N (B) dimension. All
/// values are then merged and returned
///
/// \throws ::ov::AssertFailure with details in case of error. Possible errors are:
/// * There is no parameter with a layout set. The Function shall have at least one parameter with a layout containing
/// an 'N' dimension. Recommended fix is to use the `Parameter::set_layout` API, e.g.
/// `function->get_parameters()[some_index]->set_layout("NCHW");`
/// * Several parameters have conflicting N dimensions, e.g. param1 NCHW{1,3,224,224} and param2 NCHW{2,3,224,224}.
/// This is ambiguous; most probably the first dimension is incorrectly marked as 'batch' (N) in one of the layouts.
/// The user shall fix it before using 'get_batch' (in the example above, correct the layout of param2 from 'NCHW' to
/// 'CHWN')
///
/// \param f function in which to look for the batch_size value
/// \return Dimension representing the current batch size. Can represent a number or be dynamic
OPENVINO_API ov::Dimension get_batch(const std::shared_ptr<const ov::Function>& f);
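
A minimal caller-side sketch of get_batch (assuming a `model` of type std::shared_ptr<ov::Function> created elsewhere; includes omitted; the layout assignment follows the recommended fix above):

model->get_parameters()[0]->set_layout("NCHW");  // ensure an 'N' dimension is declared
ov::Dimension batch = ov::get_batch(model);      // throws ov::AssertFailure on conflicts
if (batch.is_static()) {
    std::cout << "batch size: " << batch.get_length() << std::endl;
} else {
    std::cout << "batch size is dynamic: " << batch << std::endl;
}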
/// \brief Helper method to set the batch size for a Function
///
/// \details Checks the layout of each parameter in a Function and sets the value of the N (B) dimension. Then performs
/// validation and type propagation
///
/// \throws ::ov::AssertFailure with details in case of error. Possible errors are:
/// * There is no parameter with an N dimension in its layout. The Function shall have at least one parameter with a
/// layout containing an 'N' dimension. Recommended fix is to use the `Parameter::set_layout` API, e.g.
/// `function->get_parameters()[some_index]->set_layout("NCHW");`
/// * Several parameters have conflicting N dimensions, e.g. param1 NCHW{1,3,224,224} and param2 NCHW{3,224,224,1}.
/// This is ambiguous (1 != 3); most probably some dimension is incorrectly marked as 'batch' (N) in one of the
/// layouts. The user shall fix it before using 'set_batch' (in the example above, correct the layout of param2 from
/// 'NCHW' to 'CHWN')
/// * Validation fails after setting the batch_size. The Function is left in an inconsistent state after the new batch
/// size value is applied. A possible reason is that the layout was not set for some parameters, or the batch size
/// cannot be applied to the model at all
///
/// \param f function in which to set the batch_size value
/// \param batch_size Batch size value. For a dynamic batch size, Dimension::dynamic() can be passed.
OPENVINO_API void set_batch(const std::shared_ptr<ov::Function>& f, ov::Dimension batch_size);
} // namespace ov
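
A corresponding sketch for set_batch, including the dynamic case mentioned in the parameter description (same illustrative `model` as above):

ov::set_batch(model, 8);                          // fixed batch of 8
ov::set_batch(model, ov::Dimension(1, 16));       // batch bounded to the range [1, 16]
ov::set_batch(model, ov::Dimension::dynamic());   // fully dynamic batch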


@@ -911,3 +911,122 @@ ov::Output<ov::Node> ov::Function::add_output(const ov::Output<ov::Node>& port)
add_results({result});
return result->output(0);
}
namespace bs_util {
static int64_t get_batch(const ov::Layout& layout, const ov::PartialShape& shape) {
auto batch_idx = ov::layout::batch_idx(layout);
if (batch_idx < 0) {
batch_idx += static_cast<int64_t>(shape.rank().get_length());
}
return batch_idx;
}
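
The negative-index branch handles layouts where 'N' is addressed from the back, for which ov::layout::batch_idx returns a negative offset; a small worked example of the arithmetic (values are illustrative, mirroring the "...N" test below):

ov::Layout layout("...N");                        // batch_idx(layout) == -1
ov::PartialShape shape{3, 224, 224, 1};           // rank 4
int64_t idx = bs_util::get_batch(layout, shape);  // -1 + 4 == 3, i.e. the last dimension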
static void dump_parameter(std::ostream& stream, const std::shared_ptr<const ov::Function>& f, size_t index) {
const auto& p = f->get_parameters()[index];
const auto& node = f->input(index);
stream << index << ": { ";
if (!node.get_tensor().get_names().empty()) {
stream << "name='" << node.get_tensor().get_any_name() << "', ";
}
stream << "shape=" << node.get_partial_shape();
if (node.get_partial_shape().rank().is_static()) {
stream << ", layout=" << p->get_layout().to_string();
if (!ov::layout::has_batch(p->get_layout())) {
stream << ", no batch specified";
} else {
stream << ", batch="
<< node.get_partial_shape()[bs_util::get_batch(p->get_layout(), node.get_partial_shape())];
}
stream << " }" << std::endl;
}
}
} // namespace bs_util
ov::Dimension ov::get_batch(const std::shared_ptr<const ov::Function>& f) {
bool batch_initialized = false;
auto batch_size = ov::Dimension::dynamic();
std::vector<size_t> merged_indexes;
merged_indexes.reserve(f->inputs().size());
for (size_t i = 0; i < f->get_parameters().size(); ++i) {
const auto& param = f->get_parameters()[i];
const auto& layout = param->get_layout();
if (!ov::layout::has_batch(layout))
continue;
const auto& pshape = param->get_partial_shape();
if (pshape.rank().is_dynamic()) {
continue; // Parameter with fully dynamic rank can't conflict
}
auto batch_idx = bs_util::get_batch(layout, pshape);
if (!Dimension::merge(batch_size, batch_size, pshape[batch_idx])) {
merged_indexes.push_back(i);
// Not all dimensions can be merged
std::stringstream stream;
stream << "Get original batch size fails due to conflicting batch values for inputs:" << std::endl;
for (size_t j = 0; j < merged_indexes.size(); ++j) {
bs_util::dump_parameter(stream, f, merged_indexes[j]);
}
stream << "---" << std::endl;
stream << "Please ensure that N(Batch) dimension is set correctly for listed parameters";
OPENVINO_ASSERT(false, stream.str());
} else {
merged_indexes.push_back(i);
}
batch_initialized = true;
}
if (!batch_initialized) {
// Create graceful message to set layout for some parameters
std::stringstream stream;
stream << "Get original batch size fails due to batch is not set in any layout for any input. ";
stream << "Available inputs:" << std::endl;
for (size_t i = 0; i < f->get_parameters().size(); ++i) {
bs_util::dump_parameter(stream, f, i);
}
stream << "---" << std::endl;
stream << "Please use 'set_layout' API to set layout with batch dimension, e.g. "
"`Function->get_parameters()[index]->set_layout(\"NCHW\");`";
OPENVINO_ASSERT(false, stream.str());
}
return batch_size;
}
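
The Dimension::merge call above is what produces interval batch sizes: merging intersects the ranges and fails when they do not overlap. A short sketch of the semantics relied on here (mirroring the get_batch_size_ranges test below):

ov::Dimension merged = ov::Dimension::dynamic();
ov::Dimension::merge(merged, merged, ov::Dimension(1, 10));         // merged == [1, 10]
ov::Dimension::merge(merged, merged, ov::Dimension(5, 15));         // merged == [5, 10] (intersection)
bool ok = ov::Dimension::merge(merged, merged, ov::Dimension(42));  // false: 42 is outside [5, 10]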
void ov::set_batch(const std::shared_ptr<ov::Function>& f, ov::Dimension batch_size) {
get_batch(f); // Ensure that function's batch size is valid and can be changed
std::map<ov::Output<ov::Node>, ov::PartialShape> new_shapes_map;
// Now batch size can be set for all needed parameters
for (size_t i = 0; i < f->get_parameters().size(); ++i) {
const auto& param = f->get_parameters()[i];
const auto& layout = param->get_layout();
if (!ov::layout::has_batch(layout))
continue;
const auto& pshape = param->get_partial_shape();
if (pshape.rank().is_dynamic()) {
continue; // Parameter with fully dynamic rank can be left as is
}
auto batch_idx = bs_util::get_batch(layout, pshape);
auto new_shape = param->get_partial_shape();
new_shape[batch_idx] = batch_size;
new_shapes_map[f->input(i)] = new_shape;
}
try {
f->reshape(new_shapes_map);
} catch (const std::exception& e) {
std::stringstream stream;
stream << "Failed to set batch size to " << batch_size << ". Possible reasons are:" << std::endl;
stream << " 1) Ensure that all inputs have valid layout set with batch dimension" << std::endl;
stream << " 2) Check model's documentation if batch size can be set to it at all" << std::endl;
stream << "Available inputs:" << std::endl;
for (size_t i = 0; i < f->get_parameters().size(); ++i) {
bs_util::dump_parameter(stream, f, i);
if (new_shapes_map.count(f->input(i))) {
stream << i << ": Tried reshape " << f->input(i).get_partial_shape() << " to "
<< new_shapes_map[f->input(i)] << std::endl;
} else {
stream << i << ": No reshape has been applied" << std::endl;
}
}
stream << "---" << std::endl;
stream << "Original error message is: " << e.what();
OPENVINO_ASSERT(false, stream.str());
}
}
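
Stripped of the layout discovery and error reporting, the body above boils down to building a shape map and calling Function::reshape; a hand-rolled sketch for a single "NCHW" input (names assumed, not part of the commit):

// Roughly what ov::set_batch(f, 8) does for one parameter with layout "NCHW".
std::map<ov::Output<ov::Node>, ov::PartialShape> new_shapes;
auto shape = f->get_parameters()[0]->get_partial_shape();
shape[0] = 8;                        // 'N' is dimension 0 in an "NCHW" layout
new_shapes[f->input(0)] = shape;
f->reshape(new_shapes);              // validation and type propagation happen here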


@@ -1235,3 +1235,204 @@ TEST(function, topological_sort_caching_shared_nodes) {
ASSERT_FALSE(f1_shared_info->get_use_topological_cache());
ASSERT_FALSE(f2_shared_info->get_use_topological_cache());
}
namespace bs_utils {
static std::shared_ptr<ov::Function> create_n_inputs(ov::element::Type type,
const std::vector<ov::PartialShape>& shapes,
const std::vector<ov::Layout>& layouts) {
ov::ResultVector res;
ov::ParameterVector params;
for (size_t i = 0; i < shapes.size(); i++) {
auto index_str = std::to_string(i);
auto data1 = std::make_shared<ov::opset8::Parameter>(type, shapes[i]);
data1->set_layout(layouts[i]);
data1->set_friendly_name("input" + index_str);
data1->get_output_tensor(0).set_names({"tensor_input" + index_str});
auto op1 = std::make_shared<ov::opset8::Relu>(data1);
op1->set_friendly_name("Relu" + index_str);
auto res1 = std::make_shared<ov::opset8::Result>(op1);
res1->set_friendly_name("Result" + index_str);
res1->get_output_tensor(0).set_names({"tensor_output" + index_str});
params.push_back(data1);
res.push_back(res1);
}
auto f = std::make_shared<ov::Function>(res, params);
f->validate_nodes_and_infer_types();
return f;
}
static std::shared_ptr<ov::Function> create_add(ov::element::Type type,
const ov::PartialShape& shape,
const ov::Layout& layout1,
const ov::Layout& layout2) {
ov::ParameterVector params;
for (size_t i = 0; i < 2; i++) {
auto index_str = std::to_string(i);
auto data1 = std::make_shared<ov::opset8::Parameter>(type, shape);
data1->set_friendly_name("input" + index_str);
data1->get_output_tensor(0).set_names({"tensor_input" + index_str});
params.push_back(data1);
}
params[0]->set_layout(layout1);
params[1]->set_layout(layout2);
auto op1 = std::make_shared<ov::opset8::Add>(params[0],
params[1],
ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::EXPLICIT));
op1->set_friendly_name("Add");
auto res1 = std::make_shared<ov::opset8::Result>(op1);
res1->get_output_tensor(0).set_names({"tensor_output"});
auto f = std::make_shared<ov::Function>(res1, params);
f->validate_nodes_and_infer_types();
return f;
}
} // namespace bs_utils
TEST(function, get_batch_size) {
auto f = bs_utils::create_n_inputs(ov::element::f32, {{1, 512, 512, 3}, {1, 3, 224, 224}}, {"NHWC", "NCHW"});
EXPECT_NO_THROW(ov::get_batch(f));
EXPECT_EQ(ov::get_batch(f), 1);
}
TEST(function, get_batch_size_with_conflict) {
auto f = bs_utils::create_n_inputs(ov::element::f32,
{ov::PartialShape::dynamic(), {5, 6}, {1, 3, 224, 224}, {3, 1}},
{"NCHW", "D...", "NCHW", "N???"});
// TODO: gtest v1.10 limitation. Replace with EXPECT_THAT for gtest >= v1.11
try {
ov::get_batch(f);
FAIL() << "get_batch shall throw";
} catch (const ov::Exception& err) {
// Verify error message contains conflicting layouts
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("NCHW").to_string()) != std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("N???").to_string()) != std::string::npos) << err.what();
// Verify error message doesn't contain non-conflicting layouts
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("D...").to_string()) == std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find("tensor_input_0") == std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find("tensor_input_1") == std::string::npos) << err.what();
} catch (...) {
FAIL() << "Expected ov::Exception";
}
}
TEST(function, get_batch_size_without_batches) {
auto f = bs_utils::create_n_inputs(ov::element::f32, {{1, 3, 224, 224}, {1, 3, 224, 224}}, {"?C...", ov::Layout()});
// TODO: replace with EXPECT_THAT after upgrade gtest to v1.11
try {
ov::get_batch(f);
FAIL() << "get_batch shall throw";
} catch (const ov::Exception& err) {
// Verify error message contains layouts without batches
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("?C...").to_string()) != std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find(ov::Layout().to_string()) != std::string::npos) << err.what();
} catch (...) {
FAIL() << "Expected ov::Exception";
}
}
TEST(function, get_batch_size_without_one_layout) {
auto f = bs_utils::create_n_inputs(ov::element::f32,
{{ov::Dimension::dynamic(), 3, 224, 224}, {10, 20}},
{"N...", "HW"});
EXPECT_EQ(ov::get_batch(f), ov::Dimension::dynamic());
}
TEST(function, get_batch_size_ranges) {
auto f = bs_utils::create_n_inputs(ov::element::f32,
{{ov::Dimension(1, 10), 3, 224, 224}, {ov::Dimension(5, 15), 3, 224, 224}},
{"NCHW", "NCHW"});
EXPECT_EQ(ov::get_batch(f), ov::Dimension(5, 10));
}
TEST(function, set_batch_size) {
auto f = bs_utils::create_n_inputs(ov::element::f32,
{{1, 512, 512, 3}, {ov::Dimension::dynamic(), 3, 224, 224}, {1, 5}},
{"NHWC", "NCHW", "??"});
EXPECT_NO_THROW(ov::set_batch(f, 4));
EXPECT_EQ(f->input("tensor_input0").get_partial_shape(), (ov::PartialShape{4, 512, 512, 3}));
EXPECT_EQ(f->input("tensor_input1").get_partial_shape(), (ov::PartialShape{4, 3, 224, 224}));
EXPECT_EQ(f->input("tensor_input2").get_partial_shape(), (ov::PartialShape{1, 5}));
}
TEST(function, set_batch_size_ranges) {
auto f = bs_utils::create_n_inputs(ov::element::f32,
{{ov::Dimension(1, 10), 3, 224, 224}, {ov::Dimension(5, 15), 3, 224, 224}},
{"NCHW", "NCHW"});
EXPECT_NO_THROW(ov::set_batch(f, 42));
EXPECT_EQ(f->input("tensor_input0").get_partial_shape(), (ov::PartialShape{42, 3, 224, 224}));
EXPECT_EQ(f->input("tensor_input1").get_partial_shape(), (ov::PartialShape{42, 3, 224, 224}));
}
TEST(function, set_batch_size_fully_dynamic) {
auto f =
bs_utils::create_n_inputs(ov::element::f32, {ov::PartialShape::dynamic(), {1, 3, 224, 224}}, {"NCHW", "NCHW"});
EXPECT_NO_THROW(ov::set_batch(f, 42));
EXPECT_EQ(f->input("tensor_input0").get_partial_shape(), (ov::PartialShape::dynamic()));
EXPECT_EQ(f->input("tensor_input1").get_partial_shape(), (ov::PartialShape{42, 3, 224, 224}));
}
TEST(function, set_batch_size_dynamic_layout) {
auto f = bs_utils::create_n_inputs(ov::element::f32, {{3, 224, 224, 1}, {1, 3, 224, 224}}, {"...N", "NCHW"});
EXPECT_NO_THROW(ov::set_batch(f, 42));
EXPECT_EQ(f->input("tensor_input0").get_partial_shape(), (ov::PartialShape{3, 224, 224, 42}));
EXPECT_EQ(f->input("tensor_input1").get_partial_shape(), (ov::PartialShape{42, 3, 224, 224}));
}
TEST(function, set_batch_size_with_conflict) {
auto f = bs_utils::create_n_inputs(ov::element::f32,
{ov::PartialShape::dynamic(), {5, 6}, {1, 3, 224, 224}, {3, 1}},
{"NCHW", "D...", "NCHW", "N???"});
// TODO: gtest v1.10 limitation. Replace with EXPECT_THAT for gtest >= v1.11
try {
ov::set_batch(f, 12);
FAIL() << "set_batch shall throw";
} catch (const ov::Exception& err) {
// Verify error message contains conflicting layouts
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("NCHW").to_string()) != std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("N???").to_string()) != std::string::npos) << err.what();
// Verify error message doesn't contain non-conflicting layouts
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("D...").to_string()) == std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find("tensor_input_0") == std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find("tensor_input_1") == std::string::npos) << err.what();
} catch (...) {
FAIL() << "Expected ov::Exception";
}
}
TEST(function, set_batch_size_without_batches) {
auto f = bs_utils::create_n_inputs(ov::element::f32, {{1, 3, 224, 224}, {1, 3, 224, 224}}, {"?C...", ov::Layout()});
// TODO: replace with EXPECT_THAT after upgrade gtest to v1.11
try {
ov::set_batch(f, 42);
FAIL() << "set_batch shall throw";
} catch (const ov::Exception& err) {
// Verify error message contains layouts without batches
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("?C...").to_string()) != std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find(ov::Layout().to_string()) != std::string::npos) << err.what();
} catch (...) {
FAIL() << "Expected ov::Exception";
}
}
TEST(function, set_batch_size_validation_throw) {
auto f = bs_utils::create_add(ov::element::f32, {1, 3, 224, 224}, "NCHW", ov::Layout());
// TODO: replace with EXPECT_THAT after upgrade gtest to v1.11
try {
ov::set_batch(f, 42);
FAIL() << "set_batch shall throw";
} catch (const ov::Exception& err) {
// Verify error message contains possible reasons
EXPECT_TRUE(std::string(err.what()).find("Possible reasons") != std::string::npos) << err.what();
// Verify error message contains all layouts
EXPECT_TRUE(std::string(err.what()).find(ov::Layout("NCHW").to_string()) != std::string::npos) << err.what();
EXPECT_TRUE(std::string(err.what()).find(ov::Layout().to_string()) != std::string::npos) << err.what();
} catch (...) {
FAIL() << "Expected ov::Exception";
}
}