提交 88223238 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

added tests for local_subtensor_make_vector

上级 d9b990de
...@@ -1712,6 +1712,42 @@ def test_local_useless_subtensor(): ...@@ -1712,6 +1712,42 @@ def test_local_useless_subtensor():
f([[1, 2, 3], [4, 5, 6]], 3) f([[1, 2, 3], [4, 5, 6]], 3)
class test_local_subtensor_make_vector(unittest.TestCase):
def test_scalar_idx(self):
x, y, z = tensor.lscalars('xyz')
v = make_vector(x, y, z)
f = function([x, y, z], v[0], mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
assert isinstance(prog[0].op, theano.compile.ops.DeepCopyOp)
assert f(0, 1, 2) == 0
def test_slice_idx_stop(self):
x, y, z = tensor.lscalars('xyz')
v = make_vector(x, y, z)
f = function([x, y, z], v[:2], mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
assert isinstance(prog[0].op, MakeVector)
assert len(prog[0].inputs) == 2
r = f(0, 1, 2)
assert r[0] == 0 and r[1] == 1
def test_slice_idx_step(self):
x, y, z = tensor.lscalars('xyz')
v = make_vector(x, y, z)
f = function([x, y, z], v[::2], mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
assert isinstance(prog[0].op, MakeVector)
assert len(prog[0].inputs) == 2
r = f(0, 1, 2)
assert r[0] == 0 and r[1] == 2
class test_local_subtensor_lift(unittest.TestCase): class test_local_subtensor_lift(unittest.TestCase):
def test0(self): def test0(self):
# basic test that the Op works # basic test that the Op works
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论