提交 cb343a76 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add an optimization to move mrg_uniform to the GPU.

While there fix the inplace optimization so that it actually works with the GPU versions.
上级 38ab51a1
...@@ -992,7 +992,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -992,7 +992,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (2, self.GpuKernelBase_version) return (3, self.GpuKernelBase_version)
def guess_n_streams(size, warn=True): def guess_n_streams(size, warn=True):
...@@ -1350,19 +1350,25 @@ class MRG_RandomStreams(object): ...@@ -1350,19 +1350,25 @@ class MRG_RandomStreams(object):
return final_samples return final_samples
from theano.sandbox.gpuarray.opt import (register_opt as register_gpua, from theano.sandbox.gpuarray.opt import (register_opt as register_gpua,
op_lifter as gpua_lifter) host_from_gpu as host_from_gpua)
@register_gpua() @register_gpua()
@gpua_lifter([mrg_uniform]) @local_optimizer([mrg_uniform])
def local_gpua_mrg(node): def local_gpua_mrg(node):
return GPUA_mrg_uniform.new(node.inputs[0], node.op.output_type.ndim, if (type(node.op) == mrg_uniform and
node.op.output_type.dtype, node.inputs[1]) isinstance(node.inputs[0].type, GpuArrayType)):
outs = GPUA_mrg_uniform.new(node.inputs[0],
node.op.output_type.ndim,
node.op.output_type.dtype,
node.inputs[1])
return [outs[0], host_from_gpua(outs[1])]
@local_optimizer([mrg_uniform]) MRG_RNGs = (mrg_uniform, GPU_mrg_uniform, GPUA_mrg_uniform)
@local_optimizer(MRG_RNGs)
def mrg_random_make_inplace(node): def mrg_random_make_inplace(node):
op = node.op op = node.op
if isinstance(op, mrg_uniform) and not op.inplace: if isinstance(op, MRG_RNGs) and not op.inplace:
# op might be gpu version # op might be gpu version
new_op = op.__class__(op.output_type, inplace=True) new_op = op.__class__(op.output_type, inplace=True)
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*node.inputs).outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论