提交 489b04e4 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

added tests for local_useless_subtensor AdvancedSubtensor1 cases

上级 8c0ea6ec
......@@ -45,6 +45,7 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
from theano.tensor import (
AdvancedSubtensor1,
as_tensor_variable,
inplace,
Join,
......@@ -1713,6 +1714,27 @@ def test_local_useless_subtensor():
f([[1, 2, 3], [4, 5, 6]], 1)
f([[1, 2, 3], [4, 5, 6]], 3)
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector
# or ARange op
for dims, res in (([0, 1], True),
([1, 0], False),
([0, 0], False),
([0, 0, 1], False),
(T.arange(2), True),
(T.arange(2, -1), False),
(T.arange(1, 2), False)):
f = function([x], tensor.exp(x_c).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog = f.maker.fgraph.toposort()
if res:
assert isinstance(prog[0].op, theano.tensor.SpecifyShape), dims
assert prog[1].op == tensor.exp, dims
assert len(prog) == 2, dims
else:
assert any([isinstance(node.op, AdvancedSubtensor1)
for node in prog])
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
class test_local_subtensor_make_vector(unittest.TestCase):
def test_scalar_idx(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论