提交 14181bd4 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

generalize Reshape.infer_shape

上级 68631c38
......@@ -5355,40 +5355,33 @@ class Reshape(Op):
# Here, we only simplify if the shape (node.inputs[1]) is a constant,
# ideally it would suffice to check that it is always non-negative.
"""
oshape = []
for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = node.inputs[1][i]
if os_i == -1:
os_i = default_os_i
except TypeError:
print 'erreur'
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
"""
# In contrast with the preceding block, the following will handle
# an entry equal to -1 in desired shape
requ = node.inputs[1]
if isinstance(requ, theano.tensor.TensorConstant):
requ = list(requ.data)
requ_part = [ele for ele in requ if ele != -1]
crit = len(requ) - len(requ_part)
if crit == 1:
missing = numpy.prod(ishapes[0]) / numpy.prod(requ_part)
for i, ele in enumerate(requ):
if ele == -1:
requ[i] = missing
elif crit > 1:
raise ValueError('shape argument to Reshape.perform'
requ = node.inputs[1]
if isinstance(requ, theano.tensor.TensorConstant):
requ = list(requ.data)
requ_part = [ele for ele in requ if ele != -1]
crit = len(requ) - len(requ_part)
if crit == 1:
missing = numpy.prod(ishapes[0]) / numpy.prod(requ_part)
for i, ele in enumerate(requ):
if ele == -1:
requ[i] = missing
elif crit > 1:
raise ValueError('shape argument to Reshape.perform'
' must have at most one entry equal to -1')
return [requ]
else:
return node.env.shape_feature.default_infer_shape(node, ishapes)
return [requ]
else:
oshape = []
for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = get_constant_value(node.inputs[1][i]).item()
if os_i == -1:
os_i = default_os_i
except TypeError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
def c_code_cache_version(self):
return (2,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论