提交 489b04e4 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

added tests for local_useless_subtensor AdvancedSubtensor1 cases

上级 8c0ea6ec
...@@ -45,6 +45,7 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector ...@@ -45,6 +45,7 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
from theano.tensor import ( from theano.tensor import (
AdvancedSubtensor1,
as_tensor_variable, as_tensor_variable,
inplace, inplace,
Join, Join,
...@@ -1713,6 +1714,27 @@ def test_local_useless_subtensor(): ...@@ -1713,6 +1714,27 @@ def test_local_useless_subtensor():
f([[1, 2, 3], [4, 5, 6]], 1) f([[1, 2, 3], [4, 5, 6]], 1)
f([[1, 2, 3], [4, 5, 6]], 3) f([[1, 2, 3], [4, 5, 6]], 3)
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector
# or ARange op
for dims, res in (([0, 1], True),
([1, 0], False),
([0, 0], False),
([0, 0, 1], False),
(T.arange(2), True),
(T.arange(2, -1), False),
(T.arange(1, 2), False)):
f = function([x], tensor.exp(x_c).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog = f.maker.fgraph.toposort()
if res:
assert isinstance(prog[0].op, theano.tensor.SpecifyShape), dims
assert prog[1].op == tensor.exp, dims
assert len(prog) == 2, dims
else:
assert any([isinstance(node.op, AdvancedSubtensor1)
for node in prog])
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
class test_local_subtensor_make_vector(unittest.TestCase): class test_local_subtensor_make_vector(unittest.TestCase):
def test_scalar_idx(self): def test_scalar_idx(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论