[core]Add IF_TYPE_OF to enable CC for IfTypeOf class (#21240)

* Integrate `IfTypeOf` class with CC
- Add macro to warp class to add support for CC
- Update `ccheader.py` to created PP symbols for template parameter list
- Use new macro in Convert operator

* Correct symbols generation

* Update OV PP macros

* Wrap TestVisitor into namespace{}

* Removed not required macros, update comments

* Update element_visitor.hpp
This commit is contained in:
Pawel Raasz 2023-11-30 08:28:45 +01:00 committed by GitHub
parent 8ee8f4e112
commit 51b5bc5ec4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 195 additions and 15 deletions

View File

@ -20,9 +20,7 @@ from glob import glob
from pathlib import Path
from abc import ABC, abstractmethod
Domain = ['SIMPLE_',
'SWITCH_',
'FACTORY_']
Domain = ["SIMPLE_", "SWITCH_", "FACTORY_", "TYPE_LIST_"]
FILE_HEADER = "#pragma once\n\n"
FILE_FOOTER = "\n"
@ -36,6 +34,7 @@ class IScope(ABC):
def generate(self, f, module):
pass
class Scope(IScope):
def __init__(self, name):
self.name = name
@ -72,6 +71,24 @@ class Factory(IScope):
if r:
f.write(ENABLED_FACTORY_INSTANCE_FMT % (module, r))
class TypeList(IScope):
et_sep = ", "
et_namespace = "::ov::element::"
def __init__(self, name):
self.name = name
self.types = set()
def append(self, type):
self.types.add(type)
def generate(self, f, module):
type_list = self.et_sep.join((self.et_namespace + t for t in self.types))
f.write(f"#define {module}_enabled_{self.name} 1\n")
f.write(f"#define {module}_{self.name} {type_list}\n")
class Module:
def __init__(self, name):
self.name = name
@ -98,6 +115,10 @@ class Module:
if self.scopes:
f.write("\n")
def type_list(self, scope_name):
return self.scopes.setdefault(scope_name, TypeList(scope_name))
class Stat:
def __init__(self, files):
self.modules = {}
@ -135,6 +156,17 @@ class Stat:
for cre in list(filter(lambda row: len(row) > 1 and row[1] == 'CREATE', factories)):
self.module(cre[0]).factory(cre[2]).create(cre[3])
# Type list generator filter, returns tuple of (domain, (region, type))
type_list_filter = (
(row[0], row[1].strip().split("$"))
for row in rows
if len(row) > 1 and row[0].startswith(Domain[3])
)
for domain, (region, type) in type_list_filter:
module = self.module(domain)
module.type_list(region).append(type)
def generate(self, out):
with open(str(out), 'w') as f:
f.write(FILE_HEADER)

View File

@ -13,7 +13,7 @@
#define OV_PP_TOSTRING(...) OV_PP_TOSTRING_(__VA_ARGS__)
#define OV_PP_TOSTRING_(...) #__VA_ARGS__
#define OV_PP_EXPAND(X) X
#define OV_PP_EXPAND(...) __VA_ARGS__
#define OV_PP_NARG(...) OV_PP_EXPAND(OV_PP_NARG_(__VA_ARGS__, OV_PP_RSEQ_N()))
#define OV_PP_NARG_(...) OV_PP_EXPAND(OV_PP_ARG_N(__VA_ARGS__))
@ -48,3 +48,6 @@
// Return second argument from possible sequences {1, 0}, {0, 1, 0}
#define OV_PP_IS_ENABLED2(arg1_or_junk) OV_PP_SECOND_ARG(arg1_or_junk 1, 0)
// Ignores inputs
#define OV_PP_IGNORE(...)

View File

@ -8,10 +8,17 @@
#include "openvino/core/except.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/itt.hpp"
namespace ov {
namespace element {
namespace itt {
namespace domains {
OV_ITT_DOMAIN(ov_eval);
} // namespace domains
} // namespace itt
OV_ITT_DOMAIN(OV_PP_CAT(TYPE_LIST_, ov_eval));
namespace element {
/**
* @brief Primary template defines suppoted element types.
*
@ -44,6 +51,13 @@ struct IfTypeOf<> {
static auto apply(Type_t et, Args&&... args) -> typename Visitor::result_type {
return Visitor::visit();
}
#if defined(SELECTIVE_BUILD_ANALYZER)
template <class Visitor, class... Args>
static auto apply(const std::string& region, Type_t et, Args&&... args) -> typename Visitor::result_type {
return Visitor::visit();
}
#endif
};
/**
@ -71,6 +85,20 @@ struct IfTypeOf<ET, Others...> {
return (et == ET) ? Visitor::template visit<ET>(std::forward<Args>(args)...)
: IfTypeOf<Others...>::template apply<Visitor>(et, std::forward<Args>(args)...);
}
#if defined(SELECTIVE_BUILD_ANALYZER)
template <class Visitor, class... Args>
static auto apply(const std::string& region, Type_t et, Args&&... args) -> typename Visitor::result_type {
return (et == ET && is_cc_enabled(region))
? Visitor::template visit<ET>(std::forward<Args>(args)...)
: IfTypeOf<Others...>::template apply<Visitor>(region, et, std::forward<Args>(args)...);
}
static bool is_cc_enabled(const std::string& region) {
OV_ITT_SCOPED_TASK(OV_PP_CAT(TYPE_LIST_, ov_eval), region + "$" + Type(ET).to_string());
return true;
}
#endif
};
/**
@ -120,3 +148,37 @@ private:
};
} // namespace element
} // namespace ov
// Return ov::elements as parameter list e.g. OV_PP_ET_LIST(f16, i32) -> f16, i32
#define OV_PP_ET_LIST(...) OV_PP_EXPAND(__VA_ARGS__)
// Helpers to implement ignore or expand if symbol exists
#define OV_PP_ET_LIST_OR_EMPTY_0(...) OV_PP_IGNORE(__VA_ARGS__)
#define OV_PP_ET_LIST_OR_EMPTY_1(...) OV_PP_EXPAND(__VA_ARGS__)
// Check if ET list defined and use it for `IfTypeOf` class or make empty list
#define OV_PP_ET_LIST_OR_EMPTY(region) \
OV_PP_EXPAND(OV_PP_CAT(OV_PP_ET_LIST_OR_EMPTY_, OV_PP_IS_ENABLED(OV_PP_CAT(TYPE_LIST_ov_eval_enabled_, region)))( \
OV_PP_CAT(TYPE_LIST_ov_eval_, region)))
/**
* @brief Use this macro wrapper for ov::element::IfTypeOf class to integrate it with
* OpenVINO conditional compilation feature.
*
* @param region Region name for ITT which will be combined with TYPE_LIST_ prefix.
* @param types List ov::element IfTypeOf class e.g. OV_PP_ET_LIST(f16, i8) to pack as one paramater.
* @param visitor Class name of visitor which will be used by IfTypeOf<types>::visit(_VA_ARGS_) function.
* @param ... List of parameters must match parameter list of `visit` function.
*
* @return Value returned by `visit` function
*/
#if defined(SELECTIVE_BUILD_ANALYZER)
# define IF_TYPE_OF(region, types, visitor, ...) \
::ov::element::IfTypeOf<types>::apply<visitor>(OV_PP_TOSTRING(region), __VA_ARGS__)
#elif defined(SELECTIVE_BUILD)
# define IF_TYPE_OF(region, types, visitor, ...) \
::ov::element::IfTypeOf<OV_PP_ET_LIST_OR_EMPTY(region)>::apply<visitor>(__VA_ARGS__)
#else
# define IF_TYPE_OF(region, types, visitor, ...) ::ov::element::IfTypeOf<types>::apply<visitor>(__VA_ARGS__)
#endif

View File

@ -26,12 +26,14 @@ struct Evaluate : public element::NoAction<bool> {
template <element::Type_t ET, class TI = fundamental_type_for<ET>>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IfTypeOf<CONVERT_ET_LIST>::apply<EvalByOutputType<is_lp_type(ET)>>(
out.get_element_type(),
reinterpret_cast<const TI*>(arg.data()),
out,
count,
ET);
return IF_TYPE_OF(Convert_out,
CONVERT_ET_LIST,
EvalByOutputType<is_lp_type(ET)>,
out.get_element_type(),
reinterpret_cast<const TI*>(arg.data()),
out,
count,
ET);
}
private:
@ -137,10 +139,13 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const
out.set_shape(in_shape);
using namespace ov::element;
return IfTypeOf<CONVERT_ET_LIST>::apply<convert::Evaluate>(in.get_element_type(),
in,
out,
shape_size(in_shape));
return IF_TYPE_OF(v0_Convert_in_et,
CONVERT_ET_LIST,
convert::Evaluate,
in.get_element_type(),
in,
out,
shape_size(in_shape));
} else {
return false;
}

View File

@ -18,6 +18,7 @@
#define SELECTIVE_BUILD_ANALYZER
#include "element_visitor.hpp"
#include "itt.hpp"
using namespace std;
@ -51,6 +52,23 @@ TEST(conditional_compilation, collect_ops_in_opset) {
#undef ov_opset_test_opset1_Abs
}
namespace {
struct TestVisitor : public ov::element::NoAction<bool> {
using ov::element::NoAction<bool>::visit;
template <ov::element::Type_t ET>
static result_type visit(int x) {
return true;
}
};
} // namespace
TEST(conditional_compilation, IF_TYPE_OF_collect_action_for_supported_element) {
using namespace ov::element;
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
EXPECT_TRUE(result);
}
#undef SELECTIVE_BUILD_ANALYZER
#ifdef SELECTIVE_BUILD_ANALYZER_ON

View File

@ -16,6 +16,7 @@
# undef SELECTIVE_BUILD
#endif
#include "element_visitor.hpp"
#include "itt.hpp"
using namespace std;
@ -45,6 +46,23 @@ TEST(conditional_compilation, all_ops_enabled_in_opset) {
EXPECT_NE(opset.create_insensitive("Constant"), nullptr);
}
namespace {
struct TestVisitor : public ov::element::NoAction<bool> {
using ov::element::NoAction<bool>::visit;
template <ov::element::Type_t ET>
static result_type visit(int x) {
return true;
}
};
} // namespace
TEST(conditional_compilation, IF_TYPE_OF_action_for_supported_element) {
using namespace ov::element;
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
EXPECT_TRUE(result);
}
#ifdef SELECTIVE_BUILD_ANALYZER_ON
# define SELECTIVE_BUILD_ANALYZER
#elif defined(SELECTIVE_BUILD_ON)

View File

@ -18,6 +18,7 @@
#define SELECTIVE_BUILD
#include "element_visitor.hpp"
#include "itt.hpp"
using namespace std;
@ -49,6 +50,45 @@ TEST(conditional_compilation, disabled_Constant_in_opset) {
#undef ov_opset_test_opset3_Abs
}
namespace {
struct TestVisitor : public ov::element::NoAction<bool> {
using ov::element::NoAction<bool>::visit;
template <ov::element::Type_t ET>
static result_type visit(int x) {
return true;
}
};
} // namespace
TEST(conditional_compilation, IF_TYPE_OF_element_type_on_cc_list) {
#define TYPE_LIST_ov_eval_enabled_test_1 1
#define TYPE_LIST_ov_eval_test_1 ::ov::element::f32, ::ov::element::u64
using namespace ov::element;
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
EXPECT_TRUE(result);
#undef TYPE_LIST_ov_eval_enabled_test_1
#undef TYPE_LIST_ov_eval_test_1
}
TEST(conditional_compilation, IF_TYPE_OF_element_type_not_on_cc_list) {
#define TYPE_LIST_ov_eval_enabled_test_1 1
#define TYPE_LIST_ov_eval_test_1 f16
using namespace ov::element;
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
EXPECT_FALSE(result);
#undef TYPE_LIST_ov_eval_enabled_test_1
#undef TYPE_LIST_ov_eval_test_1
}
TEST(conditional_compilation, IF_TYPE_OF_no_element_list) {
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
EXPECT_FALSE(result);
}
#undef SELECTIVE_BUILD
#ifdef SELECTIVE_BUILD_ANALYZER_ON

View File

@ -11,6 +11,7 @@
using namespace testing;
using namespace ov::element;
namespace {
struct TestVisitor : public ov::element::NotSupported<bool> {
using ov::element::NotSupported<bool>::visit;
@ -46,6 +47,7 @@ struct TestVisitorVoidReturn : public ov::element::NoAction<void> {
};
int TestVisitorVoidReturn::test_value;
} // namespace
class IfTypeOfTest : public Test {
protected: