提交 bfae9c44 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3643 from SinaHonari/issue3420

Rewrite Reshape.infer_shape
...@@ -4401,6 +4401,11 @@ class Reshape(Op): ...@@ -4401,6 +4401,11 @@ class Reshape(Op):
# Here, we only simplify if the shape (node.inputs[1]) is a constant, # Here, we only simplify if the shape (node.inputs[1]) is a constant,
# ideally it would suffice to check that it is always non-negative. # ideally it would suffice to check that it is always non-negative.
# If current variable is a scalar and its dimensionality should
# change to self.ndim, then use size 1 for all new dimensions.
if len(ishapes[0]) == 0:
return [(1,) * self.ndim]
requ = node.inputs[1] requ = node.inputs[1]
if isinstance(requ, theano.tensor.TensorConstant): if isinstance(requ, theano.tensor.TensorConstant):
requ = list(requ.data) requ = list(requ.data)
...@@ -4418,17 +4423,12 @@ class Reshape(Op): ...@@ -4418,17 +4423,12 @@ class Reshape(Op):
' must have at most one entry equal to -1') ' must have at most one entry equal to -1')
return [requ] return [requ]
else: else:
oshape = [] new_dims = [node.inputs[1][i] for i in xrange(self.ndim)]
for i in xrange(self.ndim): return [tuple([switch(eq(new_dims[i], -1),
default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0]) theano.tensor.mul(*ishapes[0]) //
try: (-theano.tensor.mul(*new_dims)),
os_i = get_scalar_constant_value(node.inputs[1][i]).item() new_dims[i])
if os_i == -1: for i in xrange(self.ndim)])]
os_i = default_os_i
except NotScalarConstantError:
os_i = default_os_i
oshape.append(os_i)
return [tuple(oshape)]
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (6,)
......
...@@ -5112,14 +5112,14 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin): ...@@ -5112,14 +5112,14 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
r = a.reshape(shapes, ndim=1) r = a.reshape(shapes, ndim=1)
z = zeros_like(r) z = zeros_like(r)
f = self.function([a, shapes], z.shape) f = self.function([a, shapes], r)
self.assertRaises(ValueError, f, a_val, [13]) self.assertRaises(ValueError, f, a_val, [13])
# Test reshape to 2 dim # Test reshape to 2 dim
r = a.reshape(shapes, ndim=2) r = a.reshape(shapes, ndim=2)
z = zeros_like(r) z = zeros_like(r)
f = self.function([a, shapes], z.shape) f = self.function([a, shapes], r)
self.assertRaises(ValueError, f, a_val, [-1, 5]) self.assertRaises(ValueError, f, a_val, [-1, 5])
self.assertRaises(ValueError, f, a_val, [7, -1]) self.assertRaises(ValueError, f, a_val, [7, -1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论