提交 1d13344e authored 作者: Xavier Bouthillier's avatar Xavier Bouthillier

Merge pull request #3323 from SinaHonari/issue3031

theano.tensor.constant reshape fix
...@@ -4126,7 +4126,11 @@ class Reshape(Op): ...@@ -4126,7 +4126,11 @@ class Reshape(Op):
x = as_tensor_variable(x) x = as_tensor_variable(x)
shp_orig = shp shp_orig = shp
shp = as_tensor_variable(shp, ndim=1) shp = as_tensor_variable(shp, ndim=1)
if not shp.dtype.startswith('int'): if not (shp.dtype.startswith('int') or
(isinstance(shp, TensorConstant) and shp.data.size == 0)):
# It raises an error if shp is not of integer type,
# except when shp is constant and empty
# (in this case, shp.dtype does not matter anymore).
raise TypeError("Shape must be integers", shp, shp.dtype) raise TypeError("Shape must be integers", shp, shp.dtype)
assert shp.ndim == 1 assert shp.ndim == 1
if isinstance(shp, TensorConstant): if isinstance(shp, TensorConstant):
......
...@@ -5052,6 +5052,11 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin): ...@@ -5052,6 +5052,11 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
f = self.function([x], x.reshape((0, 100))) f = self.function([x], x.reshape((0, 100)))
assert f(numpy.ndarray((0,), dtype='float32')).shape == (0, 100) assert f(numpy.ndarray((0,), dtype='float32')).shape == (0, 100)
def test_empty_shp(self):
const = theano.tensor.constant([1]).reshape(())
f = function([], const)
assert f().shape == ()
def test_make_column_matrix_broadcastable(): def test_make_column_matrix_broadcastable():
# The goal of the operation made by `b` is to ensure the second dimension # The goal of the operation made by `b` is to ensure the second dimension
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论