Removed wrong check from FC shape inference (#3332)

This commit is contained in:
Gleb Kazantaev 2020-11-25 07:01:58 +03:00 committed by GitHub
parent 5fc2724199
commit 2b70fa1473
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 151 additions and 2 deletions

View File

@ -28,8 +28,6 @@ shared_ptr<Node> op::FullyConnected::clone_with_new_inputs(const OutputVector& n
}
void op::FullyConnected::validate_and_infer_types() {
if (m_output_shape.size() < 2)
throw ngraph_error("FullyConnected shape is incorrect");
m_output_size = m_output_shape.back();
set_output_type(
0,

View File

@ -564,3 +564,154 @@ TEST_F(NGraphReaderTests, ReadMatMulNetwork5) {
)V0G0N";
compareIRs(model, modelV5, 48);
}
TEST_F(NGraphReaderTests, ReadMatMul1DNetwork) {
std::string model = R"V0G0N(
<net name="Network" version="10">
<layers>
<layer id="0" name="data" type="Parameter" version="opset1">
<data element_type="f32" shape="2048"/>
<output>
<port id="0" precision="FP32">
<dim>2048</dim>
</port>
</output>
</layer>
<layer id="1" name="embedded_input__const" type="Const" version="opset1">
<data element_type="f32" offset="0" shape="2048,1000" size="8192000"/>
<output>
<port id="1" precision="FP32">
<dim>2048</dim>
<dim>1000</dim>
</port>
</output>
</layer>
<layer id="3" name="fc" type="MatMul" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>2048</dim>
</port>
<port id="1" precision="FP32">
<dim>2048</dim>
<dim>1000</dim>
</port>
</input>
<output>
<port id="3" precision="FP32">
<dim>1000</dim>
</port>
</output>
</layer>
<layer name="output" type="Result" id="2" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>1000</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="3" to-port="0"/>
<edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
<edge from-layer="3" from-port="3" to-layer="2" to-port="0"/>
</edges>
</net>
)V0G0N";
// 'fc' layer biases are fake and added due to IE limitation for Fully Connected layer
std::string modelV5 = R"V0G0N(
<?xml version="1.0"?>
<net name="Network" version="6" batch="1">
<layers>
<layer name="data" type="Input" precision="FP32" id="0">
<data originalLayersNames="data" />
<output>
<port id="0" precision="FP32">
<dim>2048</dim>
</port>
</output>
</layer>
<layer name="Constant_735" type="Const" precision="I64" id="1">
<output>
<port id="0" precision="I64">
<dim>2</dim>
</port>
</output>
<blobs>
<custom offset="0" size="16" precision="I64" />
</blobs>
</layer>
<layer name="fc/Reshape" type="Reshape" precision="FP32" id="2">
<data dim="" originalLayersNames="fc" />
<input>
<port id="0">
<dim>2048</dim>
</port>
<port id="1">
<dim>2</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>1</dim>
<dim>2048</dim>
</port>
</output>
</layer>
<layer name="FullyConnected_737" type="FullyConnected" precision="FP32" id="3">
<data originalLayersNames="fc" out-size="1000" />
<input>
<port id="0">
<dim>1</dim>
<dim>2048</dim>
</port>
</input>
<output>
<port id="1" precision="FP32">
<dim>1</dim>
<dim>1000</dim>
</port>
</output>
<blobs>
<biases offset="16" size="4000" precision="FP32" />
<weights offset="4016" size="8192000" precision="FP32" />
</blobs>
</layer>
<layer name="Constant_738" type="Const" precision="I64" id="4">
<output>
<port id="0" precision="I64">
<dim>1</dim>
</port>
</output>
<blobs>
<custom offset="8196016" size="8" precision="I64" />
</blobs>
</layer>
<layer name="fc" type="Reshape" precision="FP32" id="5">
<data dim="" originalLayersNames="fc" />
<input>
<port id="0">
<dim>1</dim>
<dim>1000</dim>
</port>
<port id="1">
<dim>1</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>1000</dim>
</port>
</output>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="2" to-port="0" />
<edge from-layer="1" from-port="0" to-layer="2" to-port="1" />
<edge from-layer="2" from-port="2" to-layer="3" to-port="0" />
<edge from-layer="3" from-port="1" to-layer="5" to-port="0" />
<edge from-layer="4" from-port="0" to-layer="5" to-port="1" />
</edges>
</net>
)V0G0N";
compareIRs(model, modelV5, 8293000);
}