[LPT] Precision restriction customization extending (#17147)

* [LPT] Precision restriction customization extending

* comments fix: refactoring
This commit is contained in:
Edward Shogulin
2023-04-26 13:29:09 +01:00
committed by GitHub
parent a032d67cc7
commit e593cf8545
5 changed files with 62 additions and 12 deletions

View File

@@ -37,20 +37,30 @@ namespace low_precision {
class PrecisionsRestriction {
public:
using PrecisionsByPorts = std::vector<std::pair<std::vector<size_t>, std::vector<ngraph::element::Type>>>;
using PrecisionsByPortsFunction = std::function<PrecisionsByPorts(const std::shared_ptr<Node>&)>;
ngraph::Node::type_info_t operationType;
bool specifyVersion;
PrecisionsByPorts precisionsByPorts;
PrecisionsByPortsFunction precisionsByPortsFunction;
PrecisionsRestriction() = default;
PrecisionsRestriction(
const ngraph::Node::type_info_t operationType,
const ngraph::Node::type_info_t& operationType,
const bool specifyVersion,
const PrecisionsByPorts& precisionsByPorts) :
operationType(operationType),
specifyVersion(specifyVersion),
precisionsByPorts(precisionsByPorts) {}
PrecisionsRestriction(
const ngraph::Node::type_info_t& operationType,
const bool specifyVersion,
const PrecisionsByPortsFunction& precisionsByPortsFunction) :
operationType(operationType),
specifyVersion(specifyVersion),
precisionsByPortsFunction(precisionsByPortsFunction) {}
template <typename T>
static PrecisionsRestriction create(
const PrecisionsByPorts& precisionsByPorts,
@@ -58,6 +68,13 @@ public:
return PrecisionsRestriction(T::get_type_info_static(), specifyVersion, precisionsByPorts);
}
template <typename T>
static PrecisionsRestriction create(
const PrecisionsByPortsFunction& precisionsByPortsFunction,
const bool specifyVersion = false) {
return PrecisionsRestriction(T::get_type_info_static(), specifyVersion, precisionsByPortsFunction);
}
template <typename T>
static PrecisionsByPorts getPrecisionsByOperationType(std::vector<PrecisionsRestriction>& restrictions) {
for (const auto& restriction : restrictions) {

View File

@@ -38,13 +38,31 @@ class ngraph::pass::low_precision::MarkupPrecisions : public ngraph::pass::Funct
public:
class Restriction {
public:
class RestrictionByVersion {
public:
RestrictionByVersion() = default;
RestrictionByVersion(
const std::function<PrecisionsRestriction::PrecisionsByPorts(const std::shared_ptr<Node>&)>& precisionsFunction,
const PrecisionsRestriction::PrecisionsByPorts& precisions) :
precisionsFunction(precisionsFunction),
precisions(precisions) {}
PrecisionsRestriction::PrecisionsByPorts get(const std::shared_ptr<Node>& node) const {
return (precisionsFunction != nullptr) ? precisionsFunction(node) : precisions;
}
private:
std::function<PrecisionsRestriction::PrecisionsByPorts(const std::shared_ptr<Node>&)> precisionsFunction;
PrecisionsRestriction::PrecisionsByPorts precisions;
};
explicit Restriction(const bool versionIsRequired) : versionIsRequired(versionIsRequired) {}
void add(const std::string version_id, const ngraph::pass::low_precision::PrecisionsRestriction::PrecisionsByPorts& precisions) {
void add(const std::string version_id, const RestrictionByVersion& precisions) {
precisionsByVersion.emplace(version_id, precisions);
}
bool versionIsRequired;
std::unordered_map<std::string, ngraph::pass::low_precision::PrecisionsRestriction::PrecisionsByPorts> precisionsByVersion;
std::unordered_map<std::string, RestrictionByVersion> precisionsByVersion;
};
OPENVINO_RTTI("MarkupPrecisions", "0");

View File

@@ -30,10 +30,14 @@ ngraph::pass::low_precision::MarkupPrecisions::MarkupPrecisions(
OPENVINO_SUPPRESS_DEPRECATED_START
if (it == restrictionsByOperation.end()) {
Restriction r(restriction.specifyVersion);
r.precisionsByVersion.emplace(restriction.operationType.version_id, restriction.precisionsByPorts);
r.precisionsByVersion.emplace(
restriction.operationType.version_id,
Restriction::RestrictionByVersion(restriction.precisionsByPortsFunction, restriction.precisionsByPorts));
restrictionsByOperation.emplace(restriction.operationType.name, r);
} else {
it->second.add(restriction.operationType.version_id, restriction.precisionsByPorts);
it->second.add(
restriction.operationType.version_id,
Restriction::RestrictionByVersion(restriction.precisionsByPortsFunction, restriction.precisionsByPorts));
}
OPENVINO_SUPPRESS_DEPRECATED_END
}
@@ -113,13 +117,13 @@ bool ngraph::pass::low_precision::MarkupPrecisions::run_on_model(const std::shar
continue;
}
const pass::low_precision::PrecisionsRestriction::PrecisionsByPorts& precisionsByPorts = it2->second;
setRestriction(node, precisionsByPorts);
const auto& precisionsByPorts = it2->second;
setRestriction(node, precisionsByPorts.get(node));
} else {
assert(r.precisionsByVersion.size() == 1ul);
const pass::low_precision::PrecisionsRestriction::PrecisionsByPorts& precisionsByPorts = r.precisionsByVersion.begin()->second;
setRestriction(node, precisionsByPorts);
const auto& precisionsByPorts = r.precisionsByVersion.begin()->second;
setRestriction(node, precisionsByPorts.get(node));
}
}
}