[IE][VPU]: Implement HSwish layer with tests (#2775)
* Implement HSwish layer with tests * Disable HSwish decomposition by a predicate * Update vpu firmware
This commit is contained in:
parent
cabf8d8534
commit
07fbf93a0d
@ -19,7 +19,7 @@ set(VPU_SUPPORTED_FIRMWARES usb-ma2450 usb-ma2x8x pcie-ma248x)
|
|||||||
# Default packages
|
# Default packages
|
||||||
#
|
#
|
||||||
|
|
||||||
set(FIRMWARE_PACKAGE_VERSION 1430)
|
set(FIRMWARE_PACKAGE_VERSION 1440)
|
||||||
set(VPU_CLC_MA2X8X_VERSION "movi-cltools-20.09.1")
|
set(VPU_CLC_MA2X8X_VERSION "movi-cltools-20.09.1")
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -156,6 +156,7 @@ public:
|
|||||||
void parseSwish(const Model& model, const ie::CNNLayerPtr& layer, const DataVector& inputs, const DataVector& outputs) const;
|
void parseSwish(const Model& model, const ie::CNNLayerPtr& layer, const DataVector& inputs, const DataVector& outputs) const;
|
||||||
void parseActivation(const Model& model, const ie::CNNLayerPtr& layer, const DataVector& inputs, const DataVector& outputs) const;
|
void parseActivation(const Model& model, const ie::CNNLayerPtr& layer, const DataVector& inputs, const DataVector& outputs) const;
|
||||||
void parseLogicalNot(const Model& model, const ie::CNNLayerPtr& layer, const DataVector& inputs, const DataVector& outputs) const;
|
void parseLogicalNot(const Model& model, const ie::CNNLayerPtr& layer, const DataVector& inputs, const DataVector& outputs) const;
|
||||||
|
void parseHSwish(const Model& model, const ie::CNNLayerPtr& layer, const DataVector& inputs, const DataVector& outputs) const;
|
||||||
|
|
||||||
//
|
//
|
||||||
// Special layers
|
// Special layers
|
||||||
|
@ -171,6 +171,7 @@ VPU_DECLARE_ENUM(StageType,
|
|||||||
StridedSlice = 133,
|
StridedSlice = 133,
|
||||||
SoftPlus = 134,
|
SoftPlus = 134,
|
||||||
Swish = 135,
|
Swish = 135,
|
||||||
|
HSwish = 137,
|
||||||
)
|
)
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -128,6 +128,7 @@ FrontEnd::FrontEnd(StageBuilder::Ptr stageBuilder, const ie::ICore* core)
|
|||||||
{"SoftPlus", LAYER_PARSER(parseSoftPlus)},
|
{"SoftPlus", LAYER_PARSER(parseSoftPlus)},
|
||||||
{"Swish", LAYER_PARSER(parseSwish)},
|
{"Swish", LAYER_PARSER(parseSwish)},
|
||||||
{"Activation", LAYER_PARSER(parseActivation)},
|
{"Activation", LAYER_PARSER(parseActivation)},
|
||||||
|
{"HSwish", LAYER_PARSER(parseHSwish)},
|
||||||
}} {
|
}} {
|
||||||
VPU_THROW_UNLESS(_core != nullptr, "Argument core is null");
|
VPU_THROW_UNLESS(_core != nullptr, "Argument core is null");
|
||||||
}
|
}
|
||||||
@ -153,7 +154,8 @@ ie::ICNNNetwork::Ptr FrontEnd::convertNetwork(ie::ICNNNetwork& network) {
|
|||||||
const bool casesWithDynamicOrStaticUsage =
|
const bool casesWithDynamicOrStaticUsage =
|
||||||
std::dynamic_pointer_cast<const ngraph::opset3::Gelu>(node) ||
|
std::dynamic_pointer_cast<const ngraph::opset3::Gelu>(node) ||
|
||||||
std::dynamic_pointer_cast<const ngraph::opset4::SoftPlus>(node) ||
|
std::dynamic_pointer_cast<const ngraph::opset4::SoftPlus>(node) ||
|
||||||
std::dynamic_pointer_cast<const ngraph::opset5::Minimum>(node);
|
std::dynamic_pointer_cast<const ngraph::opset5::Minimum>(node) ||
|
||||||
|
std::dynamic_pointer_cast<const ngraph::opset5::HSwish>(node);
|
||||||
|
|
||||||
const bool casesWithOnlyDynamicUsage =
|
const bool casesWithOnlyDynamicUsage =
|
||||||
(std::dynamic_pointer_cast<const ngraph::opset3::MatMul>(node) ||
|
(std::dynamic_pointer_cast<const ngraph::opset3::MatMul>(node) ||
|
||||||
@ -177,6 +179,7 @@ ie::ICNNNetwork::Ptr FrontEnd::convertNetwork(ie::ICNNNetwork& network) {
|
|||||||
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||||
|
|
||||||
manager.set_callback(transformationsPredicate);
|
manager.set_callback(transformationsPredicate);
|
||||||
manager.run_passes(nGraphFunc);
|
manager.run_passes(nGraphFunc);
|
||||||
|
|
||||||
|
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <vpu/frontend/frontend.hpp>
|
||||||
|
#include <vpu/stages/post_op_stage.hpp>
|
||||||
|
|
||||||
|
namespace vpu {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class HSwishStage final : public PostOpStage {
|
||||||
|
public:
|
||||||
|
using PostOpStage::PostOpStage;
|
||||||
|
|
||||||
|
private:
|
||||||
|
StagePtr cloneImpl() const override {
|
||||||
|
return std::make_shared<HSwishStage >(*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
void serializeParamsImpl(BlobSerializer&) const override {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void FrontEnd::parseHSwish(const Model& model, const ie::CNNLayerPtr& layer, const DataVector& inputs, const DataVector& outputs) const {
|
||||||
|
VPU_THROW_UNLESS((inputs.size() == 1),
|
||||||
|
"HSwish stage with name {} must have only 1 input, "
|
||||||
|
"actually provided {}", layer->name, inputs.size());
|
||||||
|
VPU_THROW_UNLESS(outputs.size() == 1,
|
||||||
|
"HSwish stage with name {} must have only 1 output, "
|
||||||
|
"actually provided {}", layer->name, outputs.size());
|
||||||
|
|
||||||
|
model->addNewStage<HSwishStage>(layer->name, StageType::HSwish, layer, inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vpu
|
@ -24,7 +24,8 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
|
|||||||
{Gelu, {}},
|
{Gelu, {}},
|
||||||
{Mish, {}},
|
{Mish, {}},
|
||||||
{SoftPlus, {}},
|
{SoftPlus, {}},
|
||||||
{Swish, {{0.05f}, {0.8f}, {1.0f}, {15.0f}}}
|
{Swish, {{0.05f}, {0.8f}, {1.0f}, {15.0f}}},
|
||||||
|
{HSwish, {}},
|
||||||
};
|
};
|
||||||
|
|
||||||
std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> basic = {
|
std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> basic = {
|
||||||
|
Loading…
Reference in New Issue
Block a user