提交 d1705c09 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5229 from ChihebTrabelsi/ccw_take

fixing subtensor.take
......@@ -2384,6 +2384,8 @@ def take(a, indices, axis=None, mode='raise'):
shape = theano.tensor.concatenate(
[indices.shape, a.shape[axis + 1:]])
else:
if axis < 0:
axis += a.ndim
shape = theano.tensor.concatenate(
[a.shape[:axis], indices.shape, a.shape[axis + 1:]])
ndim = a.ndim + indices.ndim - 1
......
......@@ -1254,6 +1254,11 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
finally:
config.warn.inc_set_subtensor1 = orig_warn
def test_take(self):
a = tensor.matrix()
f = theano.function([a], a.take(0, axis=-1), allow_input_downcast=True)
x = f(numpy.random.normal(0, 1, (30, 4)))
class TestIncSubtensor1(unittest.TestCase):
# test inc_subtensor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论