Remove leaky relu alpha check (#6910)
* remove leaky relu check * make alpha_node scalar
This commit is contained in:
parent
b11a2220b0
commit
3826a0d08d
@ -21,11 +21,8 @@ namespace ngraph
|
|||||||
auto data = node.get_ng_inputs().at(0);
|
auto data = node.get_ng_inputs().at(0);
|
||||||
double alpha = node.get_attribute_value<double>("alpha", 0.01);
|
double alpha = node.get_attribute_value<double>("alpha", 0.01);
|
||||||
|
|
||||||
CHECK_VALID_NODE(
|
|
||||||
node, alpha >= 0 && alpha <= 1, " alpha value should be in range (0,1)");
|
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> alpha_node =
|
std::shared_ptr<ngraph::Node> alpha_node =
|
||||||
default_opset::Constant::create(data.get_element_type(), Shape{}, {alpha});
|
default_opset::Constant::create(data.get_element_type(), Shape{1}, {alpha});
|
||||||
return {std::make_shared<default_opset::PRelu>(data, alpha_node)};
|
return {std::make_shared<default_opset::PRelu>(data, alpha_node)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user