提交 54938b8a authored 作者: Sina Honari's avatar Sina Honari

first commit for issue #3420

上级 21adebb5
......@@ -4417,17 +4417,12 @@ class Reshape(Op):
' must have at most one entry equal to -1')
return [requ]
else:
oshape = []
for i in xrange(self.ndim):
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
try:
os_i = get_scalar_constant_value(node.inputs[1][i]).item()
if os_i == -1:
os_i = default_os_i
except NotScalarConstantError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
new_dims = [node.inputs[1][i] for i in xrange(self.ndim)]
return [tuple([switch(eq(new_dims[i], -1),
theano.tensor.mul(*ishapes[0]) /
(-theano.tensor.mul(*new_dims)),
new_dims[i])
for i in xrange(self.ndim)])]
def c_code_cache_version(self):
return (6,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论