提交 a37a02dc authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

subtendor.take has been fixed.

上级 ecfc65ec
...@@ -2384,6 +2384,8 @@ def take(a, indices, axis=None, mode='raise'): ...@@ -2384,6 +2384,8 @@ def take(a, indices, axis=None, mode='raise'):
shape = theano.tensor.concatenate( shape = theano.tensor.concatenate(
[indices.shape, a.shape[axis + 1:]]) [indices.shape, a.shape[axis + 1:]])
else: else:
if axis < 0:
axis += a.ndim
shape = theano.tensor.concatenate( shape = theano.tensor.concatenate(
[a.shape[:axis], indices.shape, a.shape[axis + 1:]]) [a.shape[:axis], indices.shape, a.shape[axis + 1:]])
ndim = a.ndim + indices.ndim - 1 ndim = a.ndim + indices.ndim - 1
......
...@@ -1248,6 +1248,11 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -1248,6 +1248,11 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
finally: finally:
config.warn.inc_set_subtensor1 = orig_warn config.warn.inc_set_subtensor1 = orig_warn
def test_take(self):
a = tensor.matrix()
f = theano.function([a], a.take(0, axis=-1), allow_input_downcast=True)
x = f(numpy.random.normal(0, 1, (30, 4)))
class TestIncSubtensor1(unittest.TestCase): class TestIncSubtensor1(unittest.TestCase):
# test inc_subtensor # test inc_subtensor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论