[OV20] ov::Function - get/set batch size (#8955)
* Initial version (no tests) * Added tests * Fix centos * Applied review comments * Renamed 'ov::util::get_batch_size' to 'ov::pass::get_batch'; the same renaming was applied for 'set_batch_size' * Changed to ov::get_batch and ov::set_batch
This commit is contained in:
parent
92bdddbfc9
commit
4abaef6702
@ -320,9 +320,8 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
// -------- Step 4. Reshape a model --------
|
||||
// Setting batch size using image count
|
||||
const size_t batch_size = imagesData.size();
|
||||
input_shape[layout::batch_idx(tensor_layout)] = batch_size;
|
||||
model->reshape({{input.get_any_name(), input_shape}});
|
||||
const auto batch_size = static_cast<int64_t>(imagesData.size());
|
||||
ov::set_batch(model, batch_size);
|
||||
slog::info << "Batch size is " << std::to_string(batch_size) << slog::endl;
|
||||
|
||||
const auto outputShape = model->output().get_shape();
|
||||
|
@ -352,4 +352,41 @@ public:
|
||||
OPENVINO_RTTI("AttributeAdapter<std::shared_ptr<Function>");
|
||||
BWDCMP_RTTI_DECLARATION;
|
||||
};
|
||||
|
||||
/// \brief Helper method to get associated batch size for a Function
|
||||
/// \details Checks layout of each parameter in a Function and extracts value for N (B) dimension. All values are then
|
||||
/// merged and returned
|
||||
///
|
||||
/// \throws ::ov::AssertFailure with details in case of error. Possible errors are:
|
||||
/// * There is no parameter with layout set. Function shall have at least one parameter with layout with 'N' dimension.
|
||||
/// Recommended fix is to use `Parameter::set_layout` API, e.g.
|
||||
/// `function->get_parameters()[some_index]->set_layout("NCHW");`
|
||||
/// * Several parameters have conflicting N dimension, e.g. param1 NCHW{1,3,224,224} and param2 NCHW{2,3,224,224}. This
|
||||
/// is ambiguous, most probably first dimension is incorrectly marked as 'batch' (N) in some layout. User shall
|
||||
/// fix it before using 'get_batch' (in the example above, correct the layout for param2 from 'NCHW' to 'CHWN')
|
||||
///
|
||||
/// \param f function where to look for a batch_size value
|
||||
/// \return Dimension representing current batch size. Can represent a number or be a dynamic
|
||||
OPENVINO_API ov::Dimension get_batch(const std::shared_ptr<const ov::Function>& f);
|
||||
|
||||
/// \brief Helper method to set batch size to a Function
|
||||
///
|
||||
/// \details Checks layout of each parameter in a Function and sets value for N (B) dimension. Then performs validation
|
||||
/// and type propagation
|
||||
///
|
||||
/// \throws ::ov::AssertFailure with details in case of error. Possible errors are:
|
||||
/// * There is no parameter with N dimension in layout. Function shall have at least one parameter with layout with 'N'
|
||||
/// dimension. Recommended fix is to use `Parameter::set_layout` API, e.g.
|
||||
/// `function->get_parameters()[some_index]->set_layout("NCHW");`
|
||||
/// * Several parameters have conflicting N dimension, e.g. param1 NCHW{1,3,224,224} and param2 NCHW{3,224,224,1}. This
|
||||
/// is ambiguous (1 != 3), most probably some dimension is incorrectly marked as 'batch' (N) in some layout. User shall
|
||||
/// fix it before using of 'set_batch' (in example above correct layout for param2 from 'NCHW' to 'CHWN')
|
||||
/// * Validation fails after setting batch_size. Function becomes in inconsistent state after new batch size value is
|
||||
/// applied. Possible reason could be that layout was not set for some parameters, or batch size can't be applied to
|
||||
/// model at all
|
||||
///
|
||||
/// \param f function where to set batch_size value
|
||||
/// \param batch_size Batch size value. For dynamic batch size, Dimension::dynamic() can be passed.
|
||||
OPENVINO_API void set_batch(const std::shared_ptr<ov::Function>& f, ov::Dimension batch_size);
|
||||
|
||||
} // namespace ov
|
||||
|
@ -911,3 +911,122 @@ ov::Output<ov::Node> ov::Function::add_output(const ov::Output<ov::Node>& port)
|
||||
add_results({result});
|
||||
return result->output(0);
|
||||
}
|
||||
|
||||
namespace bs_util {
|
||||
// Resolves the batch ('N') index of 'layout' into a non-negative dimension
// index applicable to 'shape'.
// \param layout  layout that is known to contain a batch dimension
// \param shape   shape whose rank is static (caller guarantees this)
// \return non-negative index of the batch dimension within 'shape'
static int64_t get_batch(const ov::Layout& layout, const ov::PartialShape& shape) {
    const auto raw_idx = ov::layout::batch_idx(layout);
    // A negative index counts from the back of the shape, so normalize it
    // against the (static) rank.
    return raw_idx >= 0 ? raw_idx : raw_idx + static_cast<int64_t>(shape.rank().get_length());
}
|
||||
|
||||
// Pretty-prints one parameter of 'f' (index, tensor name, shape, layout and,
// when available, the batch value) into 'stream' as a "<index>: { ... }"
// record; used to build human-readable error messages.
// \param stream  destination stream
// \param f       function whose parameter is dumped
// \param index   index of the parameter to dump
static void dump_parameter(std::ostream& stream, const std::shared_ptr<const ov::Function>& f, size_t index) {
    const auto& p = f->get_parameters()[index];
    const auto& node = f->input(index);
    stream << index << ": { ";
    if (!node.get_tensor().get_names().empty()) {
        stream << "name='" << node.get_tensor().get_any_name() << "', ";
    }
    stream << "shape=" << node.get_partial_shape();
    // Layout/batch details are only meaningful when the rank is static
    if (node.get_partial_shape().rank().is_static()) {
        stream << ", layout=" << p->get_layout().to_string();
        if (!ov::layout::has_batch(p->get_layout())) {
            stream << ", no batch specified";
        } else {
            stream << ", batch="
                   << node.get_partial_shape()[bs_util::get_batch(p->get_layout(), node.get_partial_shape())];
        }
    }
    // Fix: close the "{ ... }" record unconditionally. Previously this line was
    // inside the rank-is-static branch, so dynamic-rank parameters produced an
    // unterminated record in the error output.
    stream << " }" << std::endl;
}
|
||||
} // namespace bs_util
|
||||
|
||||
// Derives the batch size of 'f' by merging the 'N' dimension of every
// parameter whose layout declares a batch. Parameters without a batch in the
// layout, or with fully dynamic rank, are ignored.
// Throws ov::AssertFailure when batch values conflict or when no parameter
// declares a batch dimension at all.
ov::Dimension ov::get_batch(const std::shared_ptr<const ov::Function>& f) {
    auto batch = ov::Dimension::dynamic();
    bool batch_found = false;
    std::vector<size_t> merged;  // indexes of parameters that took part in merging
    merged.reserve(f->inputs().size());
    const auto& params = f->get_parameters();
    for (size_t idx = 0; idx < params.size(); ++idx) {
        const auto& layout = params[idx]->get_layout();
        if (!ov::layout::has_batch(layout))
            continue;
        const auto& pshape = params[idx]->get_partial_shape();
        if (pshape.rank().is_dynamic()) {
            continue;  // Parameter with fully dynamic rank can't conflict
        }
        const auto n_idx = bs_util::get_batch(layout, pshape);
        merged.push_back(idx);
        if (!Dimension::merge(batch, batch, pshape[n_idx])) {
            // Conflict detected - report every parameter that took part in merging
            std::stringstream stream;
            stream << "Get original batch size fails due to conflicting batch values for inputs:" << std::endl;
            for (const auto merged_idx : merged) {
                bs_util::dump_parameter(stream, f, merged_idx);
            }
            stream << "---" << std::endl;
            stream << "Please ensure that N(Batch) dimension is set correctly for listed parameters";
            OPENVINO_ASSERT(false, stream.str());
        }
        batch_found = true;
    }
    if (!batch_found) {
        // No parameter carries a batch dimension - explain how to declare one
        std::stringstream stream;
        stream << "Get original batch size fails due to batch is not set in any layout for any input. ";
        stream << "Available inputs:" << std::endl;
        for (size_t idx = 0; idx < params.size(); ++idx) {
            bs_util::dump_parameter(stream, f, idx);
        }
        stream << "---" << std::endl;
        stream << "Please use 'set_layout' API to set layout with batch dimension, e.g. "
                  "`Function->get_parameters()[index]->set_layout(\"NCHW\");`";

        OPENVINO_ASSERT(false, stream.str());
    }
    return batch;
}
|
||||
|
||||
// Applies 'batch_size' to the 'N' dimension of every parameter whose layout
// declares a batch (parameters with fully dynamic rank are left untouched),
// then re-validates the function via reshape().
// Throws ov::AssertFailure with a detailed per-input dump when the current
// batch configuration is invalid or when reshape/validation fails.
void ov::set_batch(const std::shared_ptr<ov::Function>& f, ov::Dimension batch_size) {
    get_batch(f);  // Ensure that function's batch size is valid and can be changed
    std::map<ov::Output<ov::Node>, ov::PartialShape> new_shapes_map;
    const auto& params = f->get_parameters();
    // Collect an updated shape for every parameter that carries a batch dimension
    for (size_t idx = 0; idx < params.size(); ++idx) {
        const auto& layout = params[idx]->get_layout();
        if (!ov::layout::has_batch(layout))
            continue;
        const auto& pshape = params[idx]->get_partial_shape();
        if (pshape.rank().is_dynamic()) {
            continue;  // Parameter with fully dynamic rank can be left as is
        }
        auto updated_shape = params[idx]->get_partial_shape();
        updated_shape[bs_util::get_batch(layout, pshape)] = batch_size;
        new_shapes_map[f->input(idx)] = updated_shape;
    }
    try {
        f->reshape(new_shapes_map);
    } catch (const std::exception& e) {
        // Reshape/validation failed - build a detailed report of what was tried
        std::stringstream stream;
        stream << "Failed to set batch size to " << batch_size << ". Possible reasons are:" << std::endl;
        stream << " 1) Ensure that all inputs have valid layout set with batch dimension" << std::endl;
        stream << " 2) Check model's documentation if batch size can be set to it at all" << std::endl;
        stream << "Available inputs:" << std::endl;
        for (size_t idx = 0; idx < params.size(); ++idx) {
            bs_util::dump_parameter(stream, f, idx);
            const auto input = f->input(idx);
            if (new_shapes_map.count(input)) {
                stream << idx << ": Tried reshape " << input.get_partial_shape() << " to " << new_shapes_map[input]
                       << std::endl;
            } else {
                stream << idx << ": No reshape has been applied" << std::endl;
            }
        }
        stream << "---" << std::endl;
        stream << "Original error message is: " << e.what();
        OPENVINO_ASSERT(false, stream.str());
    }
}
|
@ -1235,3 +1235,204 @@ TEST(function, topological_sort_caching_shared_nodes) {
|
||||
ASSERT_FALSE(f1_shared_info->get_use_topological_cache());
|
||||
ASSERT_FALSE(f2_shared_info->get_use_topological_cache());
|
||||
}
|
||||
|
||||
namespace bs_utils {
|
||||
// Builds a test Function with one Parameter->Relu->Result chain per entry of
// 'shapes'; layouts[i] is assigned to parameter i. Tensors are named
// "tensor_input<i>" / "tensor_output<i>".
static std::shared_ptr<ov::Function> create_n_inputs(ov::element::Type type,
                                                     const std::vector<ov::PartialShape>& shapes,
                                                     const std::vector<ov::Layout>& layouts) {
    ov::ResultVector results;
    ov::ParameterVector parameters;
    for (size_t i = 0; i < shapes.size(); i++) {
        const auto suffix = std::to_string(i);
        auto param = std::make_shared<ov::opset8::Parameter>(type, shapes[i]);
        param->set_layout(layouts[i]);
        param->set_friendly_name("input" + suffix);
        param->get_output_tensor(0).set_names({"tensor_input" + suffix});
        auto relu = std::make_shared<ov::opset8::Relu>(param);
        relu->set_friendly_name("Relu" + suffix);
        auto result = std::make_shared<ov::opset8::Result>(relu);
        result->set_friendly_name("Result" + suffix);
        result->get_output_tensor(0).set_names({"tensor_output" + suffix});
        parameters.push_back(param);
        results.push_back(result);
    }
    auto f = std::make_shared<ov::Function>(results, parameters);
    f->validate_nodes_and_infer_types();
    return f;
}
|
||||
|
||||
// Builds a test Function with two parameters of identical 'shape' feeding an
// Add with EXPLICIT broadcast; layout1/layout2 are assigned to the parameters.
// Tensors are named "tensor_input<i>" / "tensor_output".
static std::shared_ptr<ov::Function> create_add(ov::element::Type type,
                                                const ov::PartialShape& shape,
                                                const ov::Layout& layout1,
                                                const ov::Layout& layout2) {
    ov::ParameterVector parameters;
    for (size_t i = 0; i < 2; i++) {
        const auto suffix = std::to_string(i);
        auto param = std::make_shared<ov::opset8::Parameter>(type, shape);
        param->set_friendly_name("input" + suffix);
        param->get_output_tensor(0).set_names({"tensor_input" + suffix});
        parameters.push_back(param);
    }
    parameters[0]->set_layout(layout1);
    parameters[1]->set_layout(layout2);
    auto add = std::make_shared<ov::opset8::Add>(parameters[0],
                                                 parameters[1],
                                                 ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::EXPLICIT));
    add->set_friendly_name("Add");
    auto result = std::make_shared<ov::opset8::Result>(add);
    result->get_output_tensor(0).set_names({"tensor_output"});
    auto f = std::make_shared<ov::Function>(result, parameters);
    f->validate_nodes_and_infer_types();
    return f;
}
|
||||
} // namespace bs_utils
|
||||
|
||||
// get_batch returns 1 when every batched layout agrees on a static batch of 1.
TEST(function, get_batch_size) {
    auto model = bs_utils::create_n_inputs(ov::element::f32, {{1, 512, 512, 3}, {1, 3, 224, 224}}, {"NHWC", "NCHW"});

    EXPECT_NO_THROW(ov::get_batch(model));
    EXPECT_EQ(ov::get_batch(model), 1);
}
|
||||
|
||||
// Conflicting batch values (1 vs 3) must make get_batch throw; the error text
// must mention only the parameters that took part in the conflicting merge.
TEST(function, get_batch_size_with_conflict) {
    auto f = bs_utils::create_n_inputs(ov::element::f32,
                                       {ov::PartialShape::dynamic(), {5, 6}, {1, 3, 224, 224}, {3, 1}},
                                       {"NCHW", "D...", "NCHW", "N???"});

    // TODO: gtest v.10 limitation. Replace with EXPECT_THAT for gtest >= v1.11
    try {
        ov::get_batch(f);
        FAIL() << "get_batch shall throw";
    } catch (const ov::Exception& err) {
        const std::string msg = err.what();
        // Verify error message contains conflicting layouts
        EXPECT_TRUE(msg.find(ov::Layout("NCHW").to_string()) != std::string::npos) << msg;
        EXPECT_TRUE(msg.find(ov::Layout("N???").to_string()) != std::string::npos) << msg;
        // Verify error message doesn't contain non-conflicting layouts
        EXPECT_TRUE(msg.find(ov::Layout("D...").to_string()) == std::string::npos) << msg;
        // Fix: tensors are named "tensor_input<i>" (no underscore before the
        // index - see create_n_inputs), so the previous "tensor_input_0" /
        // "tensor_input_1" lookups passed vacuously regardless of the message.
        EXPECT_TRUE(msg.find("tensor_input0") == std::string::npos) << msg;
        EXPECT_TRUE(msg.find("tensor_input1") == std::string::npos) << msg;
    } catch (...) {
        FAIL() << "Expected ov::Exception";
    }
}
|
||||
|
||||
// get_batch must throw when no parameter layout contains a batch dimension,
// and the error text must list the offending layouts.
TEST(function, get_batch_size_without_batches) {
    auto model =
        bs_utils::create_n_inputs(ov::element::f32, {{1, 3, 224, 224}, {1, 3, 224, 224}}, {"?C...", ov::Layout()});

    // TODO: replace with EXPECT_THAT after upgrade gtest to v1.11
    try {
        ov::get_batch(model);
        FAIL() << "get_batch shall throw";
    } catch (const ov::Exception& err) {
        const std::string msg = err.what();
        // Verify error message contains layouts without batches
        EXPECT_TRUE(msg.find(ov::Layout("?C...").to_string()) != std::string::npos) << msg;
        EXPECT_TRUE(msg.find(ov::Layout().to_string()) != std::string::npos) << msg;
    } catch (...) {
        FAIL() << "Expected ov::Exception";
    }
}
|
||||
|
||||
// Parameters whose layout has no 'N' are ignored; a dynamic batch dimension on
// the only batched parameter yields a dynamic result.
TEST(function, get_batch_size_without_one_layout) {
    auto model = bs_utils::create_n_inputs(ov::element::f32,
                                           {{ov::Dimension::dynamic(), 3, 224, 224}, {10, 20}},
                                           {"N...", "HW"});
    EXPECT_EQ(ov::get_batch(model), ov::Dimension::dynamic());
}
|
||||
|
||||
// Interval batch dimensions are merged into their intersection: [1,10] and
// [5,15] give [5,10].
TEST(function, get_batch_size_ranges) {
    auto model = bs_utils::create_n_inputs(ov::element::f32,
                                           {{ov::Dimension(1, 10), 3, 224, 224}, {ov::Dimension(5, 15), 3, 224, 224}},
                                           {"NCHW", "NCHW"});
    EXPECT_EQ(ov::get_batch(model), ov::Dimension(5, 10));
}
|
||||
|
||||
// set_batch updates the 'N' dimension of every parameter with a batched layout
// and leaves parameters without a batch dimension ("??") untouched.
// Fix: removed the unused local `ov::PartialShape pshape({1, 4, 3, 3});`.
TEST(function, set_batch_size) {
    auto f = bs_utils::create_n_inputs(ov::element::f32,
                                       {{1, 512, 512, 3}, {ov::Dimension::dynamic(), 3, 224, 224}, {1, 5}},
                                       {"NHWC", "NCHW", "??"});
    EXPECT_NO_THROW(ov::set_batch(f, 4));
    EXPECT_EQ(f->input("tensor_input0").get_partial_shape(), (ov::PartialShape{4, 512, 512, 3}));
    EXPECT_EQ(f->input("tensor_input1").get_partial_shape(), (ov::PartialShape{4, 3, 224, 224}));
    EXPECT_EQ(f->input("tensor_input2").get_partial_shape(), (ov::PartialShape{1, 5}));
}
|
||||
|
||||
// A concrete batch value replaces interval batch dimensions on all inputs.
TEST(function, set_batch_size_ranges) {
    auto model = bs_utils::create_n_inputs(ov::element::f32,
                                           {{ov::Dimension(1, 10), 3, 224, 224}, {ov::Dimension(5, 15), 3, 224, 224}},
                                           {"NCHW", "NCHW"});
    EXPECT_NO_THROW(ov::set_batch(model, 42));
    EXPECT_EQ(model->input("tensor_input0").get_partial_shape(), (ov::PartialShape{42, 3, 224, 224}));
    EXPECT_EQ(model->input("tensor_input1").get_partial_shape(), (ov::PartialShape{42, 3, 224, 224}));
}
|
||||
|
||||
// Inputs with fully dynamic rank are skipped by set_batch and keep their shape.
TEST(function, set_batch_size_fully_dynamic) {
    auto model =
        bs_utils::create_n_inputs(ov::element::f32, {ov::PartialShape::dynamic(), {1, 3, 224, 224}}, {"NCHW", "NCHW"});
    EXPECT_NO_THROW(ov::set_batch(model, 42));
    EXPECT_EQ(model->input("tensor_input0").get_partial_shape(), (ov::PartialShape::dynamic()));
    EXPECT_EQ(model->input("tensor_input1").get_partial_shape(), (ov::PartialShape{42, 3, 224, 224}));
}
|
||||
|
||||
// A layout with leading "..." ("...N") resolves the batch from the back of the
// shape, so the last dimension is updated.
TEST(function, set_batch_size_dynamic_layout) {
    auto model = bs_utils::create_n_inputs(ov::element::f32, {{3, 224, 224, 1}, {1, 3, 224, 224}}, {"...N", "NCHW"});
    EXPECT_NO_THROW(ov::set_batch(model, 42));
    EXPECT_EQ(model->input("tensor_input0").get_partial_shape(), (ov::PartialShape{3, 224, 224, 42}));
    EXPECT_EQ(model->input("tensor_input1").get_partial_shape(), (ov::PartialShape{42, 3, 224, 224}));
}
|
||||
|
||||
// set_batch first validates the existing batch configuration, so conflicting
// batch values must make it throw with the same diagnostics as get_batch.
TEST(function, set_batch_size_with_conflict) {
    auto f = bs_utils::create_n_inputs(ov::element::f32,
                                       {ov::PartialShape::dynamic(), {5, 6}, {1, 3, 224, 224}, {3, 1}},
                                       {"NCHW", "D...", "NCHW", "N???"});

    // TODO: gtest v.10 limitation. Replace with EXPECT_THAT for gtest >= v1.11
    try {
        ov::set_batch(f, 12);
        FAIL() << "set_batch shall throw";
    } catch (const ov::Exception& err) {
        const std::string msg = err.what();
        // Verify error message contains conflicting layouts
        EXPECT_TRUE(msg.find(ov::Layout("NCHW").to_string()) != std::string::npos) << msg;
        EXPECT_TRUE(msg.find(ov::Layout("N???").to_string()) != std::string::npos) << msg;
        // Verify error message doesn't contain non-conflicting layouts
        EXPECT_TRUE(msg.find(ov::Layout("D...").to_string()) == std::string::npos) << msg;
        // Fix: tensors are named "tensor_input<i>" (no underscore before the
        // index - see create_n_inputs), so the previous "tensor_input_0" /
        // "tensor_input_1" lookups passed vacuously regardless of the message.
        EXPECT_TRUE(msg.find("tensor_input0") == std::string::npos) << msg;
        EXPECT_TRUE(msg.find("tensor_input1") == std::string::npos) << msg;
    } catch (...) {
        FAIL() << "Expected ov::Exception";
    }
}
|
||||
|
||||
// set_batch must throw when no parameter layout contains a batch dimension,
// and the error text must list the offending layouts.
TEST(function, set_batch_size_without_batches) {
    auto model =
        bs_utils::create_n_inputs(ov::element::f32, {{1, 3, 224, 224}, {1, 3, 224, 224}}, {"?C...", ov::Layout()});

    // TODO: replace with EXPECT_THAT after upgrade gtest to v1.11
    try {
        ov::set_batch(model, 42);
        FAIL() << "set_batch shall throw";
    } catch (const ov::Exception& err) {
        const std::string msg = err.what();
        // Verify error message contains layouts without batches
        EXPECT_TRUE(msg.find(ov::Layout("?C...").to_string()) != std::string::npos) << msg;
        EXPECT_TRUE(msg.find(ov::Layout().to_string()) != std::string::npos) << msg;
    } catch (...) {
        FAIL() << "Expected ov::Exception";
    }
}
|
||||
|
||||
// With EXPLICIT broadcast on Add, reshaping only the batched input breaks
// validation, so set_batch must fail and list the possible reasons plus the
// layouts of all inputs.
TEST(function, set_batch_size_validation_throw) {
    auto model = bs_utils::create_add(ov::element::f32, {1, 3, 224, 224}, "NCHW", ov::Layout());

    // TODO: replace with EXPECT_THAT after upgrade gtest to v1.11
    try {
        ov::set_batch(model, 42);
        FAIL() << "set_batch shall throw";
    } catch (const ov::Exception& err) {
        const std::string msg = err.what();
        // Verify error message contains possible reasons
        EXPECT_TRUE(msg.find("Possible reasons") != std::string::npos) << msg;
        // Verify error message contains all layouts
        EXPECT_TRUE(msg.find(ov::Layout("NCHW").to_string()) != std::string::npos) << msg;
        EXPECT_TRUE(msg.find(ov::Layout().to_string()) != std::string::npos) << msg;
    } catch (...) {
        FAIL() << "Expected ov::Exception";
    }
}
|
||||
|
Loading…
Reference in New Issue
Block a user