[core]Add IF_TYPE_OF
to enable CC for IfTypeOf
class (#21240)
* Integrate `IfTypeOf` class with CC - Add macro to warp class to add support for CC - Update `ccheader.py` to created PP symbols for template parameter list - Use new macro in Convert operator * Correct symbols generation * Update OV PP macros * Wrap TestVisitor into namespace{} * Removed not required macros, update comments * Update element_visitor.hpp
This commit is contained in:
parent
8ee8f4e112
commit
51b5bc5ec4
@ -20,9 +20,7 @@ from glob import glob
|
||||
from pathlib import Path
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
Domain = ['SIMPLE_',
|
||||
'SWITCH_',
|
||||
'FACTORY_']
|
||||
Domain = ["SIMPLE_", "SWITCH_", "FACTORY_", "TYPE_LIST_"]
|
||||
|
||||
FILE_HEADER = "#pragma once\n\n"
|
||||
FILE_FOOTER = "\n"
|
||||
@ -36,6 +34,7 @@ class IScope(ABC):
|
||||
def generate(self, f, module):
|
||||
pass
|
||||
|
||||
|
||||
class Scope(IScope):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
@ -72,6 +71,24 @@ class Factory(IScope):
|
||||
if r:
|
||||
f.write(ENABLED_FACTORY_INSTANCE_FMT % (module, r))
|
||||
|
||||
|
||||
class TypeList(IScope):
|
||||
et_sep = ", "
|
||||
et_namespace = "::ov::element::"
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.types = set()
|
||||
|
||||
def append(self, type):
|
||||
self.types.add(type)
|
||||
|
||||
def generate(self, f, module):
|
||||
type_list = self.et_sep.join((self.et_namespace + t for t in self.types))
|
||||
f.write(f"#define {module}_enabled_{self.name} 1\n")
|
||||
f.write(f"#define {module}_{self.name} {type_list}\n")
|
||||
|
||||
|
||||
class Module:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
@ -98,6 +115,10 @@ class Module:
|
||||
if self.scopes:
|
||||
f.write("\n")
|
||||
|
||||
def type_list(self, scope_name):
|
||||
return self.scopes.setdefault(scope_name, TypeList(scope_name))
|
||||
|
||||
|
||||
class Stat:
|
||||
def __init__(self, files):
|
||||
self.modules = {}
|
||||
@ -135,6 +156,17 @@ class Stat:
|
||||
for cre in list(filter(lambda row: len(row) > 1 and row[1] == 'CREATE', factories)):
|
||||
self.module(cre[0]).factory(cre[2]).create(cre[3])
|
||||
|
||||
# Type list generator filter, returns tuple of (domain, (region, type))
|
||||
type_list_filter = (
|
||||
(row[0], row[1].strip().split("$"))
|
||||
for row in rows
|
||||
if len(row) > 1 and row[0].startswith(Domain[3])
|
||||
)
|
||||
|
||||
for domain, (region, type) in type_list_filter:
|
||||
module = self.module(domain)
|
||||
module.type_list(region).append(type)
|
||||
|
||||
def generate(self, out):
|
||||
with open(str(out), 'w') as f:
|
||||
f.write(FILE_HEADER)
|
||||
|
@ -13,7 +13,7 @@
|
||||
#define OV_PP_TOSTRING(...) OV_PP_TOSTRING_(__VA_ARGS__)
|
||||
#define OV_PP_TOSTRING_(...) #__VA_ARGS__
|
||||
|
||||
#define OV_PP_EXPAND(X) X
|
||||
#define OV_PP_EXPAND(...) __VA_ARGS__
|
||||
|
||||
#define OV_PP_NARG(...) OV_PP_EXPAND(OV_PP_NARG_(__VA_ARGS__, OV_PP_RSEQ_N()))
|
||||
#define OV_PP_NARG_(...) OV_PP_EXPAND(OV_PP_ARG_N(__VA_ARGS__))
|
||||
@ -48,3 +48,6 @@
|
||||
|
||||
// Return second argument from possible sequences {1, 0}, {0, 1, 0}
|
||||
#define OV_PP_IS_ENABLED2(arg1_or_junk) OV_PP_SECOND_ARG(arg1_or_junk 1, 0)
|
||||
|
||||
// Ignores inputs
|
||||
#define OV_PP_IGNORE(...)
|
||||
|
@ -8,10 +8,17 @@
|
||||
|
||||
#include "openvino/core/except.hpp"
|
||||
#include "openvino/core/type/element_type.hpp"
|
||||
#include "openvino/itt.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace element {
|
||||
namespace itt {
|
||||
namespace domains {
|
||||
OV_ITT_DOMAIN(ov_eval);
|
||||
} // namespace domains
|
||||
} // namespace itt
|
||||
|
||||
OV_ITT_DOMAIN(OV_PP_CAT(TYPE_LIST_, ov_eval));
|
||||
namespace element {
|
||||
/**
|
||||
* @brief Primary template defines suppoted element types.
|
||||
*
|
||||
@ -44,6 +51,13 @@ struct IfTypeOf<> {
|
||||
static auto apply(Type_t et, Args&&... args) -> typename Visitor::result_type {
|
||||
return Visitor::visit();
|
||||
}
|
||||
|
||||
#if defined(SELECTIVE_BUILD_ANALYZER)
|
||||
template <class Visitor, class... Args>
|
||||
static auto apply(const std::string& region, Type_t et, Args&&... args) -> typename Visitor::result_type {
|
||||
return Visitor::visit();
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
/**
|
||||
@ -71,6 +85,20 @@ struct IfTypeOf<ET, Others...> {
|
||||
return (et == ET) ? Visitor::template visit<ET>(std::forward<Args>(args)...)
|
||||
: IfTypeOf<Others...>::template apply<Visitor>(et, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
#if defined(SELECTIVE_BUILD_ANALYZER)
|
||||
template <class Visitor, class... Args>
|
||||
static auto apply(const std::string& region, Type_t et, Args&&... args) -> typename Visitor::result_type {
|
||||
return (et == ET && is_cc_enabled(region))
|
||||
? Visitor::template visit<ET>(std::forward<Args>(args)...)
|
||||
: IfTypeOf<Others...>::template apply<Visitor>(region, et, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static bool is_cc_enabled(const std::string& region) {
|
||||
OV_ITT_SCOPED_TASK(OV_PP_CAT(TYPE_LIST_, ov_eval), region + "$" + Type(ET).to_string());
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
/**
|
||||
@ -120,3 +148,37 @@ private:
|
||||
};
|
||||
} // namespace element
|
||||
} // namespace ov
|
||||
|
||||
// Return ov::elements as parameter list e.g. OV_PP_ET_LIST(f16, i32) -> f16, i32
|
||||
#define OV_PP_ET_LIST(...) OV_PP_EXPAND(__VA_ARGS__)
|
||||
|
||||
// Helpers to implement ignore or expand if symbol exists
|
||||
#define OV_PP_ET_LIST_OR_EMPTY_0(...) OV_PP_IGNORE(__VA_ARGS__)
|
||||
#define OV_PP_ET_LIST_OR_EMPTY_1(...) OV_PP_EXPAND(__VA_ARGS__)
|
||||
|
||||
// Check if ET list defined and use it for `IfTypeOf` class or make empty list
|
||||
#define OV_PP_ET_LIST_OR_EMPTY(region) \
|
||||
OV_PP_EXPAND(OV_PP_CAT(OV_PP_ET_LIST_OR_EMPTY_, OV_PP_IS_ENABLED(OV_PP_CAT(TYPE_LIST_ov_eval_enabled_, region)))( \
|
||||
OV_PP_CAT(TYPE_LIST_ov_eval_, region)))
|
||||
|
||||
/**
|
||||
* @brief Use this macro wrapper for ov::element::IfTypeOf class to integrate it with
|
||||
* OpenVINO conditional compilation feature.
|
||||
*
|
||||
* @param region Region name for ITT which will be combined with TYPE_LIST_ prefix.
|
||||
* @param types List ov::element IfTypeOf class e.g. OV_PP_ET_LIST(f16, i8) to pack as one paramater.
|
||||
* @param visitor Class name of visitor which will be used by IfTypeOf<types>::visit(_VA_ARGS_) function.
|
||||
* @param ... List of parameters must match parameter list of `visit` function.
|
||||
*
|
||||
* @return Value returned by `visit` function
|
||||
*/
|
||||
|
||||
#if defined(SELECTIVE_BUILD_ANALYZER)
|
||||
# define IF_TYPE_OF(region, types, visitor, ...) \
|
||||
::ov::element::IfTypeOf<types>::apply<visitor>(OV_PP_TOSTRING(region), __VA_ARGS__)
|
||||
#elif defined(SELECTIVE_BUILD)
|
||||
# define IF_TYPE_OF(region, types, visitor, ...) \
|
||||
::ov::element::IfTypeOf<OV_PP_ET_LIST_OR_EMPTY(region)>::apply<visitor>(__VA_ARGS__)
|
||||
#else
|
||||
# define IF_TYPE_OF(region, types, visitor, ...) ::ov::element::IfTypeOf<types>::apply<visitor>(__VA_ARGS__)
|
||||
#endif
|
||||
|
@ -26,12 +26,14 @@ struct Evaluate : public element::NoAction<bool> {
|
||||
template <element::Type_t ET, class TI = fundamental_type_for<ET>>
|
||||
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<CONVERT_ET_LIST>::apply<EvalByOutputType<is_lp_type(ET)>>(
|
||||
out.get_element_type(),
|
||||
reinterpret_cast<const TI*>(arg.data()),
|
||||
out,
|
||||
count,
|
||||
ET);
|
||||
return IF_TYPE_OF(Convert_out,
|
||||
CONVERT_ET_LIST,
|
||||
EvalByOutputType<is_lp_type(ET)>,
|
||||
out.get_element_type(),
|
||||
reinterpret_cast<const TI*>(arg.data()),
|
||||
out,
|
||||
count,
|
||||
ET);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -137,10 +139,13 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const
|
||||
out.set_shape(in_shape);
|
||||
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<CONVERT_ET_LIST>::apply<convert::Evaluate>(in.get_element_type(),
|
||||
in,
|
||||
out,
|
||||
shape_size(in_shape));
|
||||
return IF_TYPE_OF(v0_Convert_in_et,
|
||||
CONVERT_ET_LIST,
|
||||
convert::Evaluate,
|
||||
in.get_element_type(),
|
||||
in,
|
||||
out,
|
||||
shape_size(in_shape));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
#define SELECTIVE_BUILD_ANALYZER
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -51,6 +52,23 @@ TEST(conditional_compilation, collect_ops_in_opset) {
|
||||
#undef ov_opset_test_opset1_Abs
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct TestVisitor : public ov::element::NoAction<bool> {
|
||||
using ov::element::NoAction<bool>::visit;
|
||||
|
||||
template <ov::element::Type_t ET>
|
||||
static result_type visit(int x) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(conditional_compilation, IF_TYPE_OF_collect_action_for_supported_element) {
|
||||
using namespace ov::element;
|
||||
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
#undef SELECTIVE_BUILD_ANALYZER
|
||||
|
||||
#ifdef SELECTIVE_BUILD_ANALYZER_ON
|
||||
|
@ -16,6 +16,7 @@
|
||||
# undef SELECTIVE_BUILD
|
||||
#endif
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -45,6 +46,23 @@ TEST(conditional_compilation, all_ops_enabled_in_opset) {
|
||||
EXPECT_NE(opset.create_insensitive("Constant"), nullptr);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct TestVisitor : public ov::element::NoAction<bool> {
|
||||
using ov::element::NoAction<bool>::visit;
|
||||
|
||||
template <ov::element::Type_t ET>
|
||||
static result_type visit(int x) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(conditional_compilation, IF_TYPE_OF_action_for_supported_element) {
|
||||
using namespace ov::element;
|
||||
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
#ifdef SELECTIVE_BUILD_ANALYZER_ON
|
||||
# define SELECTIVE_BUILD_ANALYZER
|
||||
#elif defined(SELECTIVE_BUILD_ON)
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
#define SELECTIVE_BUILD
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -49,6 +50,45 @@ TEST(conditional_compilation, disabled_Constant_in_opset) {
|
||||
#undef ov_opset_test_opset3_Abs
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct TestVisitor : public ov::element::NoAction<bool> {
|
||||
using ov::element::NoAction<bool>::visit;
|
||||
|
||||
template <ov::element::Type_t ET>
|
||||
static result_type visit(int x) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(conditional_compilation, IF_TYPE_OF_element_type_on_cc_list) {
|
||||
#define TYPE_LIST_ov_eval_enabled_test_1 1
|
||||
#define TYPE_LIST_ov_eval_test_1 ::ov::element::f32, ::ov::element::u64
|
||||
|
||||
using namespace ov::element;
|
||||
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
|
||||
EXPECT_TRUE(result);
|
||||
|
||||
#undef TYPE_LIST_ov_eval_enabled_test_1
|
||||
#undef TYPE_LIST_ov_eval_test_1
|
||||
}
|
||||
|
||||
TEST(conditional_compilation, IF_TYPE_OF_element_type_not_on_cc_list) {
|
||||
#define TYPE_LIST_ov_eval_enabled_test_1 1
|
||||
#define TYPE_LIST_ov_eval_test_1 f16
|
||||
|
||||
using namespace ov::element;
|
||||
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
|
||||
EXPECT_FALSE(result);
|
||||
|
||||
#undef TYPE_LIST_ov_eval_enabled_test_1
|
||||
#undef TYPE_LIST_ov_eval_test_1
|
||||
}
|
||||
|
||||
TEST(conditional_compilation, IF_TYPE_OF_no_element_list) {
|
||||
const auto result = IF_TYPE_OF(test_1, OV_PP_ET_LIST(f32), TestVisitor, ov::element::f32, 10);
|
||||
EXPECT_FALSE(result);
|
||||
}
|
||||
#undef SELECTIVE_BUILD
|
||||
|
||||
#ifdef SELECTIVE_BUILD_ANALYZER_ON
|
||||
|
@ -11,6 +11,7 @@
|
||||
using namespace testing;
|
||||
using namespace ov::element;
|
||||
|
||||
namespace {
|
||||
struct TestVisitor : public ov::element::NotSupported<bool> {
|
||||
using ov::element::NotSupported<bool>::visit;
|
||||
|
||||
@ -46,6 +47,7 @@ struct TestVisitorVoidReturn : public ov::element::NoAction<void> {
|
||||
};
|
||||
|
||||
int TestVisitorVoidReturn::test_value;
|
||||
} // namespace
|
||||
|
||||
class IfTypeOfTest : public Test {
|
||||
protected:
|
||||
|
Loading…
Reference in New Issue
Block a user