[LPT] Precision restriction customization extending (#17147)
* [LPT] Precision restriction customization extending * comments fix: refactoring
This commit is contained in:
@@ -37,20 +37,30 @@ namespace low_precision {
|
||||
class PrecisionsRestriction {
|
||||
public:
|
||||
using PrecisionsByPorts = std::vector<std::pair<std::vector<size_t>, std::vector<ngraph::element::Type>>>;
|
||||
using PrecisionsByPortsFunction = std::function<PrecisionsByPorts(const std::shared_ptr<Node>&)>;
|
||||
|
||||
ngraph::Node::type_info_t operationType;
|
||||
bool specifyVersion;
|
||||
PrecisionsByPorts precisionsByPorts;
|
||||
PrecisionsByPortsFunction precisionsByPortsFunction;
|
||||
|
||||
PrecisionsRestriction() = default;
|
||||
PrecisionsRestriction(
|
||||
const ngraph::Node::type_info_t operationType,
|
||||
const ngraph::Node::type_info_t& operationType,
|
||||
const bool specifyVersion,
|
||||
const PrecisionsByPorts& precisionsByPorts) :
|
||||
operationType(operationType),
|
||||
specifyVersion(specifyVersion),
|
||||
precisionsByPorts(precisionsByPorts) {}
|
||||
|
||||
PrecisionsRestriction(
|
||||
const ngraph::Node::type_info_t& operationType,
|
||||
const bool specifyVersion,
|
||||
const PrecisionsByPortsFunction& precisionsByPortsFunction) :
|
||||
operationType(operationType),
|
||||
specifyVersion(specifyVersion),
|
||||
precisionsByPortsFunction(precisionsByPortsFunction) {}
|
||||
|
||||
template <typename T>
|
||||
static PrecisionsRestriction create(
|
||||
const PrecisionsByPorts& precisionsByPorts,
|
||||
@@ -58,6 +68,13 @@ public:
|
||||
return PrecisionsRestriction(T::get_type_info_static(), specifyVersion, precisionsByPorts);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static PrecisionsRestriction create(
|
||||
const PrecisionsByPortsFunction& precisionsByPortsFunction,
|
||||
const bool specifyVersion = false) {
|
||||
return PrecisionsRestriction(T::get_type_info_static(), specifyVersion, precisionsByPortsFunction);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static PrecisionsByPorts getPrecisionsByOperationType(std::vector<PrecisionsRestriction>& restrictions) {
|
||||
for (const auto& restriction : restrictions) {
|
||||
|
||||
@@ -38,13 +38,31 @@ class ngraph::pass::low_precision::MarkupPrecisions : public ngraph::pass::Funct
|
||||
public:
|
||||
class Restriction {
|
||||
public:
|
||||
class RestrictionByVersion {
|
||||
public:
|
||||
RestrictionByVersion() = default;
|
||||
RestrictionByVersion(
|
||||
const std::function<PrecisionsRestriction::PrecisionsByPorts(const std::shared_ptr<Node>&)>& precisionsFunction,
|
||||
const PrecisionsRestriction::PrecisionsByPorts& precisions) :
|
||||
precisionsFunction(precisionsFunction),
|
||||
precisions(precisions) {}
|
||||
|
||||
PrecisionsRestriction::PrecisionsByPorts get(const std::shared_ptr<Node>& node) const {
|
||||
return (precisionsFunction != nullptr) ? precisionsFunction(node) : precisions;
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<PrecisionsRestriction::PrecisionsByPorts(const std::shared_ptr<Node>&)> precisionsFunction;
|
||||
PrecisionsRestriction::PrecisionsByPorts precisions;
|
||||
};
|
||||
|
||||
explicit Restriction(const bool versionIsRequired) : versionIsRequired(versionIsRequired) {}
|
||||
void add(const std::string version_id, const ngraph::pass::low_precision::PrecisionsRestriction::PrecisionsByPorts& precisions) {
|
||||
void add(const std::string version_id, const RestrictionByVersion& precisions) {
|
||||
precisionsByVersion.emplace(version_id, precisions);
|
||||
}
|
||||
|
||||
bool versionIsRequired;
|
||||
std::unordered_map<std::string, ngraph::pass::low_precision::PrecisionsRestriction::PrecisionsByPorts> precisionsByVersion;
|
||||
std::unordered_map<std::string, RestrictionByVersion> precisionsByVersion;
|
||||
};
|
||||
|
||||
OPENVINO_RTTI("MarkupPrecisions", "0");
|
||||
|
||||
@@ -30,10 +30,14 @@ ngraph::pass::low_precision::MarkupPrecisions::MarkupPrecisions(
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
if (it == restrictionsByOperation.end()) {
|
||||
Restriction r(restriction.specifyVersion);
|
||||
r.precisionsByVersion.emplace(restriction.operationType.version_id, restriction.precisionsByPorts);
|
||||
r.precisionsByVersion.emplace(
|
||||
restriction.operationType.version_id,
|
||||
Restriction::RestrictionByVersion(restriction.precisionsByPortsFunction, restriction.precisionsByPorts));
|
||||
restrictionsByOperation.emplace(restriction.operationType.name, r);
|
||||
} else {
|
||||
it->second.add(restriction.operationType.version_id, restriction.precisionsByPorts);
|
||||
it->second.add(
|
||||
restriction.operationType.version_id,
|
||||
Restriction::RestrictionByVersion(restriction.precisionsByPortsFunction, restriction.precisionsByPorts));
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
@@ -113,13 +117,13 @@ bool ngraph::pass::low_precision::MarkupPrecisions::run_on_model(const std::shar
|
||||
continue;
|
||||
}
|
||||
|
||||
const pass::low_precision::PrecisionsRestriction::PrecisionsByPorts& precisionsByPorts = it2->second;
|
||||
setRestriction(node, precisionsByPorts);
|
||||
const auto& precisionsByPorts = it2->second;
|
||||
setRestriction(node, precisionsByPorts.get(node));
|
||||
} else {
|
||||
assert(r.precisionsByVersion.size() == 1ul);
|
||||
|
||||
const pass::low_precision::PrecisionsRestriction::PrecisionsByPorts& precisionsByPorts = r.precisionsByVersion.begin()->second;
|
||||
setRestriction(node, precisionsByPorts);
|
||||
const auto& precisionsByPorts = r.precisionsByVersion.begin()->second;
|
||||
setRestriction(node, precisionsByPorts.get(node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user