提交 eda01ce9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make use of the relevant optimization explicit in TestLocalSubtensorMakeVector

上级 8a51bc8d
......@@ -560,21 +560,24 @@ class TestSubtensorIncSubtensor:
class TestLocalSubtensorMakeVector:
mode = get_mode("FAST_RUN").including("local_subtensor_make_vector")
def test_scalar_idx(self):
x, y, z = lscalars("xyz")
v = make_vector(x, y, z)
f = function([x, y, z], v[0], mode=mode_opt)
f = function([x, y, z], v[0], mode=self.mode)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
assert isinstance(prog[0].op, DeepCopyOp)
assert f(0, 1, 2) == 0
def test_idx_sybmolic(self):
def test_idx_symbolic(self):
x, y, z = iscalars("xyz")
v = MakeVector("int32")(x, y, z)
idx = aet.as_tensor([0], dtype=np.int64)
f = function([x, y, z], v[idx], mode=mode_opt)
f = function([x, y, z], v[idx], mode=self.mode)
opt_fgraph = f.maker.fgraph
assert opt_fgraph.outputs[0].dtype == "int32"
......@@ -585,7 +588,7 @@ class TestLocalSubtensorMakeVector:
def test_slice_idx_start(self):
x, y, z = iscalars("xyz")
v = MakeVector("int32")(x, y, z)
f = function([x, y, z], v[1:], mode=mode_opt, on_unused_input="ignore")
f = function([x, y, z], v[1:], mode=self.mode, on_unused_input="ignore")
opt_fgraph = f.maker.fgraph
assert opt_fgraph.outputs[0].dtype == "int32"
......@@ -597,7 +600,7 @@ class TestLocalSubtensorMakeVector:
def test_slice_idx_stop(self):
x, y, z = lscalars("xyz")
v = make_vector(x, y, z)
f = function([x, y, z], v[:2], mode=mode_opt)
f = function([x, y, z], v[:2], mode=self.mode)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
......@@ -609,7 +612,7 @@ class TestLocalSubtensorMakeVector:
def test_slice_idx_step(self):
x, y, z = lscalars("xyz")
v = make_vector(x, y, z)
f = function([x, y, z], v[::2], mode=mode_opt)
f = function([x, y, z], v[::2], mode=self.mode)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
......@@ -621,7 +624,7 @@ class TestLocalSubtensorMakeVector:
def test_AdvancedSubtensor1_idx(self):
x, y, z = lscalars("xyz")
v = make_vector(x, y, z)
f = function([x, y, z], v[[0, 2]], mode=mode_opt)
f = function([x, y, z], v[[0, 2]], mode=self.mode)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
......@@ -634,7 +637,7 @@ class TestLocalSubtensorMakeVector:
x, y, z, q = lscalars("xyzq")
v = make_vector(x, y, z)
q = make_vector(0, 2)
f = function([x, y, z], v[q], mode=mode_opt)
f = function([x, y, z], v[q], mode=self.mode)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论