提交 001938e4 authored 作者: Shawn Tan's avatar Shawn Tan

Make sure AllocDiag is not in the graph when testing shape.

上级 ed59b021
...@@ -7599,17 +7599,18 @@ def test_alloc_diag(): ...@@ -7599,17 +7599,18 @@ def test_alloc_diag():
# Test infer_shape # Test infer_shape
f_shape = theano.function([x], adiag_op(x).shape) f_shape = theano.function([x], adiag_op(x).shape)
assert isinstance(topo[0].op, DeepCopyOp)
theano.printing.debugprint(f_shape.maker.fgraph.outputs[0]) theano.printing.debugprint(f_shape.maker.fgraph.outputs[0])
output_shape = f_shape(test_val) output_shape = f_shape(test_val)
assert np.all(diag_arr.shape == output_shape) assert any(isinstance(node.op, AllocDiag)
for node in f_shape.maker.fgraph.toposort()])
rediag_shape = np.diagonal( rediag_shape = np.diagonal(
np.ones(output_shape), np.ones(output_shape),
offset=offset, offset=offset,
axis1=axis1, axis1=axis1,
axis2=axis2 axis2=axis2
).shape ).shape
print(rediag_shape)
print(test_val.shape)
assert np.all(rediag_shape == test_val.shape) assert np.all(rediag_shape == test_val.shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论