提交 6e632722 authored 作者: nouiz's avatar nouiz

Merge pull request #1289 from lamblin/long_reshape

Enable long ints in reshape
...@@ -6124,7 +6124,8 @@ def stack(*tensors): ...@@ -6124,7 +6124,8 @@ def stack(*tensors):
# See ticket #660 # See ticket #660
if numpy.all([ if numpy.all([
# in case there is direct int in tensors. # in case there is direct int in tensors.
isinstance(t, (numpy.number, float, int, python_complex)) or isinstance(t, (numpy.number, float, int, python_complex,
long)) or
(isinstance(t, Variable) and (isinstance(t, Variable) and
isinstance(t.type, TensorType) and isinstance(t.type, TensorType) and
t.ndim == 0) t.ndim == 0)
......
...@@ -5150,6 +5150,12 @@ class T_reshape(unittest.TestCase): ...@@ -5150,6 +5150,12 @@ class T_reshape(unittest.TestCase):
assert numpy.all(f_sub(a_val, b_val) == [2, 3]) assert numpy.all(f_sub(a_val, b_val) == [2, 3])
def test_reshape_long_in_shape(self):
v = vector('v')
r = v.reshape((v.shape[0], 1L))
print r.eval({v: numpy.arange(5.)})
assert numpy.allclose(r.eval({v: numpy.arange(5.)}).T, numpy.arange(5.))
def test_bad_shape(self): def test_bad_shape(self):
a = matrix('a') a = matrix('a')
shapes = ivector('shapes') shapes = ivector('shapes')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论