提交 575e4594 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Rename split_op tp split_op_class

上级 34a639aa
......@@ -954,7 +954,7 @@ class T_Join_and_Split(theano.tensor.tests.test_basic.T_Join_and_Split):
self.mode = mode_with_gpu.excluding('constant_folding')
self.join_op = cuda.GpuJoin()
# No gpu split.
self.split_op = tensor.Split
self.split_op_class = tensor.Split
# No Make vector on the gpu, Join used instead
self.make_vector_op = cuda.GpuJoin()
self.floatX = "float32"
......
......@@ -364,7 +364,7 @@ class G_Join_and_Split(test_basic.T_Join_and_Split):
super(G_Join_and_Split, self).setUp()
self.mode = mode_with_gpu.excluding('constant_folding')
self.join_op = GpuJoin()
self.split_op = GpuSplit
self.split_op_class = GpuSplit
# Use join instead of MakeVector since there is no MakeVector on GPU
self.make_vector_op = GpuJoin()
# this is to avoid errors with limited devices
......
......@@ -3193,7 +3193,7 @@ class T_Join_and_Split(unittest.TestCase):
'constant_folding'
)
self.join_op = Join()
self.split_op = Split
self.split_op_class = Split
self.make_vector_op = opt.MakeVector()
self.floatX = config.floatX
self.hide_error = theano.config.mode not in ['DebugMode',
......@@ -3784,9 +3784,9 @@ class T_Join_and_Split(unittest.TestCase):
def test_split_0elem(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [4, 0])
o = self.split_op_class(2)(m, 0, [4, 0])
f = function([], o, mode=self.mode)
assert any([isinstance(node.op, self.split_op)
assert any([isinstance(node.op, self.split_op_class)
for node in f.maker.fgraph.toposort()])
o1, o2 = f()
assert numpy.allclose(o1, m.get_value(borrow=True))
......@@ -3795,9 +3795,9 @@ class T_Join_and_Split(unittest.TestCase):
def test_split_neg(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [5, -1])
o = self.split_op_class(2)(m, 0, [5, -1])
f = function([], o, mode=self.mode)
assert any([isinstance(node.op, self.split_op)
assert any([isinstance(node.op, self.split_op_class)
for node in f.maker.fgraph.toposort()])
self.assertRaises(ValueError, f)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论