提交 14181bd4 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

generalize Reshape.infer_shape

上级 68631c38
...@@ -5355,40 +5355,33 @@ class Reshape(Op): ...@@ -5355,40 +5355,33 @@ class Reshape(Op):
# Here, we only simplify if the shape (node.inputs[1]) is a constant, # Here, we only simplify if the shape (node.inputs[1]) is a constant,
# ideally it would suffice to check that it is always non-negative. # ideally it would suffice to check that it is always non-negative.
""" requ = node.inputs[1]
oshape = [] if isinstance(requ, theano.tensor.TensorConstant):
for i in xrange(self.ndim): requ = list(requ.data)
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0]) requ_part = [ele for ele in requ if ele != -1]
try: crit = len(requ) - len(requ_part)
os_i = node.inputs[1][i] if crit == 1:
if os_i == -1: missing = numpy.prod(ishapes[0]) / numpy.prod(requ_part)
os_i = default_os_i for i, ele in enumerate(requ):
except TypeError: if ele == -1:
print 'erreur' requ[i] = missing
os_i = default_os_i elif crit > 1:
oshape.append(os_i) raise ValueError('shape argument to Reshape.perform'
return [tuple(oshape)]
"""
# In contrast with the preceding block, the following will handle
# an entry equal to -1 in desired shape
requ = node.inputs[1]
if isinstance(requ, theano.tensor.TensorConstant):
requ = list(requ.data)
requ_part = [ele for ele in requ if ele != -1]
crit = len(requ) - len(requ_part)
if crit == 1:
missing = numpy.prod(ishapes[0]) / numpy.prod(requ_part)
for i, ele in enumerate(requ):
if ele == -1:
requ[i] = missing
elif crit > 1:
raise ValueError('shape argument to Reshape.perform'
' must have at most one entry equal to -1') ' must have at most one entry equal to -1')
return [requ] return [requ]
else: else:
return node.env.shape_feature.default_infer_shape(node, ishapes) oshape = []
for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = get_constant_value(node.inputs[1][i]).item()
if os_i == -1:
os_i = default_os_i
except TypeError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论