[IE][VPU]: Workaround to support parameter Beta for layer Swish (#2205)
* Workaround to full support Swish layer. It is faster than native Swish for now.
This commit is contained in:
parent
a0938a92d4
commit
9e8b42ff95
@ -3,6 +3,7 @@
|
||||
//
|
||||
|
||||
#include <vpu/middleend/pass_manager.hpp>
|
||||
#include <vpu/model/data_contents/replicated_data_content.hpp>
|
||||
|
||||
namespace vpu {
|
||||
|
||||
@ -31,16 +32,33 @@ void PassImpl::run(const Model& model) {
|
||||
const auto outputData = swish->output(0);
|
||||
const auto name = swish->name();
|
||||
const auto& layer = swish->origLayer();
|
||||
const auto beta = swish->attrs().get<float>("beta");
|
||||
|
||||
model->removeStage(swish);
|
||||
auto sigmoidInput = inputData;
|
||||
|
||||
const auto sigmoidOutput = model->addNewData(inputData->name() + "@sigmoid", inputData->desc());
|
||||
if (beta != 1.0f) {
|
||||
const auto betaDesc = DataDesc(inputData->desc());
|
||||
const auto betaConst = model->addConstData(inputData->name() + "@beta", betaDesc,
|
||||
replicateContent(beta, betaDesc.totalDimSize(), betaDesc));
|
||||
const auto prodOutput = model->addNewData(inputData->name() + "@prod-x-beta", inputData->desc());
|
||||
_stageBuilder->addProdStage(
|
||||
model,
|
||||
name + "@prod-x-beta",
|
||||
layer,
|
||||
inputData,
|
||||
betaConst,
|
||||
prodOutput);
|
||||
sigmoidInput = prodOutput;
|
||||
}
|
||||
const auto sigmoidDesc = inputData->desc();
|
||||
const auto sigmoidOutput = model->addNewData(inputData->name() + "@sigmoid", sigmoidDesc);
|
||||
|
||||
_stageBuilder->addSigmoidStage(
|
||||
model,
|
||||
name + "@sigmoid",
|
||||
layer,
|
||||
{inputData},
|
||||
{sigmoidInput},
|
||||
{sigmoidOutput});
|
||||
_stageBuilder->addProdStage(
|
||||
model,
|
||||
|
@ -24,7 +24,7 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
|
||||
{Gelu, {}},
|
||||
{Mish, {}},
|
||||
{SoftPlus, {}},
|
||||
{Swish, {{1.0f}}} // {{0.05f}, {0.8f}, {1.0f}, {15.0f}}} #38489
|
||||
{Swish, {{0.05f}, {0.8f}, {1.0f}, {15.0f}}}
|
||||
};
|
||||
|
||||
std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> basic = {
|
||||
|
Loading…
Reference in New Issue
Block a user