提交 14181bd4 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

generalize Reshape.infer_shape

上级 68631c38
......@@ -5356,23 +5356,6 @@ class Reshape(Op):
# Here, we only simplify if the shape (node.inputs[1]) is a constant,
# ideally it would suffice to check that it is always non-negative.
"""
oshape = []
for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = node.inputs[1][i]
if os_i == -1:
os_i = default_os_i
except TypeError:
print 'erreur'
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
"""
# In contrast with the preceding block, the following will handle
# an entry equal to -1 in desired shape
requ = node.inputs[1]
if isinstance(requ, theano.tensor.TensorConstant):
requ = list(requ.data)
......@@ -5388,7 +5371,17 @@ class Reshape(Op):
' must have at most one entry equal to -1')
return [requ]
else:
return node.env.shape_feature.default_infer_shape(node, ishapes)
oshape = []
for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = get_constant_value(node.inputs[1][i]).item()
if os_i == -1:
os_i = default_os_i
except TypeError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
def c_code_cache_version(self):
return (2,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论