Commit 46c2a563 authored by Frederic Bastien

Make GpuSplit c code work in float16

Parent 94720cc4
@@ -1341,6 +1341,8 @@ class GpuSplit(HideC, Split):
     Split for GPU.
     """
+    _f16_ok = True
+
     def __init__(self, len_splits):
         super(GpuSplit, self).__init__(len_splits)
         # The GPU version of Split returns splits as views of the input.
...
@@ -358,9 +358,12 @@ class G_Join_and_Split(test_basic.T_Join_and_Split):
         self.shared = shared

     def test_gpusplit_opt(self):
+        # Test that we move the node to the GPU
+        # Also test float16 computation at the same time.
         rng = np.random.RandomState(seed=utt.fetch_seed())
-        m = self.shared(rng.rand(4, 6).astype(self.floatX))
+        m = self.shared(rng.rand(4, 6).astype('float16'))
         o = T.Split(2)(m, 0, [2, 2])
+        assert o[0].dtype == 'float16'
         f = theano.function([], o, mode=self.mode)
         assert any([isinstance(node.op, self.split_op_class)
                     for node in f.maker.fgraph.toposort()])
...
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment