提交 eda01ce9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make use of the relevant optimization explicit in TestLocalSubtensorMakeVector

上级 8a51bc8d
...@@ -560,21 +560,24 @@ class TestSubtensorIncSubtensor: ...@@ -560,21 +560,24 @@ class TestSubtensorIncSubtensor:
class TestLocalSubtensorMakeVector: class TestLocalSubtensorMakeVector:
mode = get_mode("FAST_RUN").including("local_subtensor_make_vector")
def test_scalar_idx(self): def test_scalar_idx(self):
x, y, z = lscalars("xyz") x, y, z = lscalars("xyz")
v = make_vector(x, y, z) v = make_vector(x, y, z)
f = function([x, y, z], v[0], mode=mode_opt) f = function([x, y, z], v[0], mode=self.mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert len(prog) == 1 assert len(prog) == 1
assert isinstance(prog[0].op, DeepCopyOp) assert isinstance(prog[0].op, DeepCopyOp)
assert f(0, 1, 2) == 0 assert f(0, 1, 2) == 0
def test_idx_sybmolic(self): def test_idx_symbolic(self):
x, y, z = iscalars("xyz") x, y, z = iscalars("xyz")
v = MakeVector("int32")(x, y, z) v = MakeVector("int32")(x, y, z)
idx = aet.as_tensor([0], dtype=np.int64) idx = aet.as_tensor([0], dtype=np.int64)
f = function([x, y, z], v[idx], mode=mode_opt) f = function([x, y, z], v[idx], mode=self.mode)
opt_fgraph = f.maker.fgraph opt_fgraph = f.maker.fgraph
assert opt_fgraph.outputs[0].dtype == "int32" assert opt_fgraph.outputs[0].dtype == "int32"
...@@ -585,7 +588,7 @@ class TestLocalSubtensorMakeVector: ...@@ -585,7 +588,7 @@ class TestLocalSubtensorMakeVector:
def test_slice_idx_start(self): def test_slice_idx_start(self):
x, y, z = iscalars("xyz") x, y, z = iscalars("xyz")
v = MakeVector("int32")(x, y, z) v = MakeVector("int32")(x, y, z)
f = function([x, y, z], v[1:], mode=mode_opt, on_unused_input="ignore") f = function([x, y, z], v[1:], mode=self.mode, on_unused_input="ignore")
opt_fgraph = f.maker.fgraph opt_fgraph = f.maker.fgraph
assert opt_fgraph.outputs[0].dtype == "int32" assert opt_fgraph.outputs[0].dtype == "int32"
...@@ -597,7 +600,7 @@ class TestLocalSubtensorMakeVector: ...@@ -597,7 +600,7 @@ class TestLocalSubtensorMakeVector:
def test_slice_idx_stop(self): def test_slice_idx_stop(self):
x, y, z = lscalars("xyz") x, y, z = lscalars("xyz")
v = make_vector(x, y, z) v = make_vector(x, y, z)
f = function([x, y, z], v[:2], mode=mode_opt) f = function([x, y, z], v[:2], mode=self.mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert len(prog) == 1 assert len(prog) == 1
...@@ -609,7 +612,7 @@ class TestLocalSubtensorMakeVector: ...@@ -609,7 +612,7 @@ class TestLocalSubtensorMakeVector:
def test_slice_idx_step(self): def test_slice_idx_step(self):
x, y, z = lscalars("xyz") x, y, z = lscalars("xyz")
v = make_vector(x, y, z) v = make_vector(x, y, z)
f = function([x, y, z], v[::2], mode=mode_opt) f = function([x, y, z], v[::2], mode=self.mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert len(prog) == 1 assert len(prog) == 1
...@@ -621,7 +624,7 @@ class TestLocalSubtensorMakeVector: ...@@ -621,7 +624,7 @@ class TestLocalSubtensorMakeVector:
def test_AdvancedSubtensor1_idx(self): def test_AdvancedSubtensor1_idx(self):
x, y, z = lscalars("xyz") x, y, z = lscalars("xyz")
v = make_vector(x, y, z) v = make_vector(x, y, z)
f = function([x, y, z], v[[0, 2]], mode=mode_opt) f = function([x, y, z], v[[0, 2]], mode=self.mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert len(prog) == 1 assert len(prog) == 1
...@@ -634,7 +637,7 @@ class TestLocalSubtensorMakeVector: ...@@ -634,7 +637,7 @@ class TestLocalSubtensorMakeVector:
x, y, z, q = lscalars("xyzq") x, y, z, q = lscalars("xyzq")
v = make_vector(x, y, z) v = make_vector(x, y, z)
q = make_vector(0, 2) q = make_vector(0, 2)
f = function([x, y, z], v[q], mode=mode_opt) f = function([x, y, z], v[q], mode=self.mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert len(prog) == 1 assert len(prog) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论