[IE][VPU]: Workaround to support parameter Beta for layer Swish (#2205)
* Workaround to full support Swish layer. It is faster than native Swish for now.
This commit is contained in:
parent
a0938a92d4
commit
9e8b42ff95
@ -3,6 +3,7 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include <vpu/middleend/pass_manager.hpp>
|
#include <vpu/middleend/pass_manager.hpp>
|
||||||
|
#include <vpu/model/data_contents/replicated_data_content.hpp>
|
||||||
|
|
||||||
namespace vpu {
|
namespace vpu {
|
||||||
|
|
||||||
@ -31,16 +32,33 @@ void PassImpl::run(const Model& model) {
|
|||||||
const auto outputData = swish->output(0);
|
const auto outputData = swish->output(0);
|
||||||
const auto name = swish->name();
|
const auto name = swish->name();
|
||||||
const auto& layer = swish->origLayer();
|
const auto& layer = swish->origLayer();
|
||||||
|
const auto beta = swish->attrs().get<float>("beta");
|
||||||
|
|
||||||
model->removeStage(swish);
|
model->removeStage(swish);
|
||||||
|
auto sigmoidInput = inputData;
|
||||||
|
|
||||||
const auto sigmoidOutput = model->addNewData(inputData->name() + "@sigmoid", inputData->desc());
|
if (beta != 1.0f) {
|
||||||
|
const auto betaDesc = DataDesc(inputData->desc());
|
||||||
|
const auto betaConst = model->addConstData(inputData->name() + "@beta", betaDesc,
|
||||||
|
replicateContent(beta, betaDesc.totalDimSize(), betaDesc));
|
||||||
|
const auto prodOutput = model->addNewData(inputData->name() + "@prod-x-beta", inputData->desc());
|
||||||
|
_stageBuilder->addProdStage(
|
||||||
|
model,
|
||||||
|
name + "@prod-x-beta",
|
||||||
|
layer,
|
||||||
|
inputData,
|
||||||
|
betaConst,
|
||||||
|
prodOutput);
|
||||||
|
sigmoidInput = prodOutput;
|
||||||
|
}
|
||||||
|
const auto sigmoidDesc = inputData->desc();
|
||||||
|
const auto sigmoidOutput = model->addNewData(inputData->name() + "@sigmoid", sigmoidDesc);
|
||||||
|
|
||||||
_stageBuilder->addSigmoidStage(
|
_stageBuilder->addSigmoidStage(
|
||||||
model,
|
model,
|
||||||
name + "@sigmoid",
|
name + "@sigmoid",
|
||||||
layer,
|
layer,
|
||||||
{inputData},
|
{sigmoidInput},
|
||||||
{sigmoidOutput});
|
{sigmoidOutput});
|
||||||
_stageBuilder->addProdStage(
|
_stageBuilder->addProdStage(
|
||||||
model,
|
model,
|
||||||
|
@ -24,7 +24,7 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
|
|||||||
{Gelu, {}},
|
{Gelu, {}},
|
||||||
{Mish, {}},
|
{Mish, {}},
|
||||||
{SoftPlus, {}},
|
{SoftPlus, {}},
|
||||||
{Swish, {{1.0f}}} // {{0.05f}, {0.8f}, {1.0f}, {15.0f}}} #38489
|
{Swish, {{0.05f}, {0.8f}, {1.0f}, {15.0f}}}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> basic = {
|
std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> basic = {
|
||||||
|
Loading…
Reference in New Issue
Block a user