提交 64dd2e47 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix reshape.infer_shape

上级 856e2872
......@@ -3795,11 +3795,29 @@ class Reshape(Op):
# It's not trivial, because we would have to check if the product of
# all the non-minus-one shapes is a divisor of the product of the
# original shapes.
return [tuple([switch(eq(node.inputs[1][i], -1),
theano.tensor.opt.Shape_i(i)(node.outputs[0]),
node.inputs[1][i])
for i in range(self.ndim)]
)]
# The following expression leads to cycles in feature_shape,
# because it tries to replace the Shape_i node by the switch
# statement, which depends on Shape_i.
#return [tuple([switch(eq(node.inputs[1][i], -1),
# theano.tensor.opt.Shape_i(i)(node.outputs[0]),
# node.inputs[1][i])
# for i in range(self.ndim)]
# )]
# Here, we only simplify if the shape (node.inputs[1]) is a constant,
# ideally it would suffice to check that it is always non-negative.
oshape = []
for i in range(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = get_constant_value(node.inputs[1][i]).item()
if os_i == -1:
os_i = default_os_i
except TypeError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
def reshape(x, newshape, ndim=None, name=None):
if ndim is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论