[IE][VPU]: Workaround to support parameter Beta for layer Swish (#2205)

* Workaround to full support Swish layer. It is faster than native Swish for now.
This commit is contained in:
Roman Vyunov (Intel) 2020-09-15 14:39:27 +03:00 committed by GitHub
parent a0938a92d4
commit 9e8b42ff95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 3 deletions

View File

@ -3,6 +3,7 @@
// //
#include <vpu/middleend/pass_manager.hpp> #include <vpu/middleend/pass_manager.hpp>
#include <vpu/model/data_contents/replicated_data_content.hpp>
namespace vpu { namespace vpu {
@ -31,16 +32,33 @@ void PassImpl::run(const Model& model) {
const auto outputData = swish->output(0); const auto outputData = swish->output(0);
const auto name = swish->name(); const auto name = swish->name();
const auto& layer = swish->origLayer(); const auto& layer = swish->origLayer();
const auto beta = swish->attrs().get<float>("beta");
model->removeStage(swish); model->removeStage(swish);
auto sigmoidInput = inputData;
const auto sigmoidOutput = model->addNewData(inputData->name() + "@sigmoid", inputData->desc()); if (beta != 1.0f) {
const auto betaDesc = DataDesc(inputData->desc());
const auto betaConst = model->addConstData(inputData->name() + "@beta", betaDesc,
replicateContent(beta, betaDesc.totalDimSize(), betaDesc));
const auto prodOutput = model->addNewData(inputData->name() + "@prod-x-beta", inputData->desc());
_stageBuilder->addProdStage(
model,
name + "@prod-x-beta",
layer,
inputData,
betaConst,
prodOutput);
sigmoidInput = prodOutput;
}
const auto sigmoidDesc = inputData->desc();
const auto sigmoidOutput = model->addNewData(inputData->name() + "@sigmoid", sigmoidDesc);
_stageBuilder->addSigmoidStage( _stageBuilder->addSigmoidStage(
model, model,
name + "@sigmoid", name + "@sigmoid",
layer, layer,
{inputData}, {sigmoidInput},
{sigmoidOutput}); {sigmoidOutput});
_stageBuilder->addProdStage( _stageBuilder->addProdStage(
model, model,

View File

@ -24,7 +24,7 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
{Gelu, {}}, {Gelu, {}},
{Mish, {}}, {Mish, {}},
{SoftPlus, {}}, {SoftPlus, {}},
{Swish, {{1.0f}}} // {{0.05f}, {0.8f}, {1.0f}, {15.0f}}} #38489 {Swish, {{0.05f}, {0.8f}, {1.0f}, {15.0f}}}
}; };
std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> basic = { std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> basic = {