提交 54938b8a authored 作者: Sina Honari's avatar Sina Honari

first commit for issue #3420

上级 21adebb5
...@@ -4417,17 +4417,12 @@ class Reshape(Op): ...@@ -4417,17 +4417,12 @@ class Reshape(Op):
' must have at most one entry equal to -1') ' must have at most one entry equal to -1')
return [requ] return [requ]
else: else:
oshape = [] new_dims = [node.inputs[1][i] for i in xrange(self.ndim)]
for i in xrange(self.ndim): return [tuple([switch(eq(new_dims[i], -1),
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0]) theano.tensor.mul(*ishapes[0]) /
try: (-theano.tensor.mul(*new_dims)),
os_i = get_scalar_constant_value(node.inputs[1][i]).item() new_dims[i])
if os_i == -1: for i in xrange(self.ndim)])]
os_i = default_os_i
except NotScalarConstantError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (6,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论