提交 575e4594 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Rename split_op to split_op_class

上级 34a639aa
...@@ -954,7 +954,7 @@ class T_Join_and_Split(theano.tensor.tests.test_basic.T_Join_and_Split): ...@@ -954,7 +954,7 @@ class T_Join_and_Split(theano.tensor.tests.test_basic.T_Join_and_Split):
self.mode = mode_with_gpu.excluding('constant_folding') self.mode = mode_with_gpu.excluding('constant_folding')
self.join_op = cuda.GpuJoin() self.join_op = cuda.GpuJoin()
# No gpu split. # No gpu split.
self.split_op = tensor.Split self.split_op_class = tensor.Split
# No Make vector on the gpu, Join used instead # No Make vector on the gpu, Join used instead
self.make_vector_op = cuda.GpuJoin() self.make_vector_op = cuda.GpuJoin()
self.floatX = "float32" self.floatX = "float32"
......
...@@ -364,7 +364,7 @@ class G_Join_and_Split(test_basic.T_Join_and_Split): ...@@ -364,7 +364,7 @@ class G_Join_and_Split(test_basic.T_Join_and_Split):
super(G_Join_and_Split, self).setUp() super(G_Join_and_Split, self).setUp()
self.mode = mode_with_gpu.excluding('constant_folding') self.mode = mode_with_gpu.excluding('constant_folding')
self.join_op = GpuJoin() self.join_op = GpuJoin()
self.split_op = GpuSplit self.split_op_class = GpuSplit
# Use join instead of MakeVector since there is no MakeVector on GPU # Use join instead of MakeVector since there is no MakeVector on GPU
self.make_vector_op = GpuJoin() self.make_vector_op = GpuJoin()
# this is to avoid errors with limited devices # this is to avoid errors with limited devices
......
...@@ -3193,7 +3193,7 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3193,7 +3193,7 @@ class T_Join_and_Split(unittest.TestCase):
'constant_folding' 'constant_folding'
) )
self.join_op = Join() self.join_op = Join()
self.split_op = Split self.split_op_class = Split
self.make_vector_op = opt.MakeVector() self.make_vector_op = opt.MakeVector()
self.floatX = config.floatX self.floatX = config.floatX
self.hide_error = theano.config.mode not in ['DebugMode', self.hide_error = theano.config.mode not in ['DebugMode',
...@@ -3784,9 +3784,9 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3784,9 +3784,9 @@ class T_Join_and_Split(unittest.TestCase):
def test_split_0elem(self): def test_split_0elem(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed()) rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX)) m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [4, 0]) o = self.split_op_class(2)(m, 0, [4, 0])
f = function([], o, mode=self.mode) f = function([], o, mode=self.mode)
assert any([isinstance(node.op, self.split_op) assert any([isinstance(node.op, self.split_op_class)
for node in f.maker.fgraph.toposort()]) for node in f.maker.fgraph.toposort()])
o1, o2 = f() o1, o2 = f()
assert numpy.allclose(o1, m.get_value(borrow=True)) assert numpy.allclose(o1, m.get_value(borrow=True))
...@@ -3795,9 +3795,9 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3795,9 +3795,9 @@ class T_Join_and_Split(unittest.TestCase):
def test_split_neg(self): def test_split_neg(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed()) rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX)) m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [5, -1]) o = self.split_op_class(2)(m, 0, [5, -1])
f = function([], o, mode=self.mode) f = function([], o, mode=self.mode)
assert any([isinstance(node.op, self.split_op) assert any([isinstance(node.op, self.split_op_class)
for node in f.maker.fgraph.toposort()]) for node in f.maker.fgraph.toposort()])
self.assertRaises(ValueError, f) self.assertRaises(ValueError, f)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论