提交 ddfa25b9 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix and disable optimization that move Multinomial to the gpu as it is bugged.

上级 64030399
import theano import theano
from theano import Op, Apply from theano import Op, Apply
import theano.tensor as T import theano.tensor as T
from theano.tensor.opt import register_specialize
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.sandbox.cuda import cuda_available, cuda_enabled from theano.sandbox.cuda import cuda_available
if cuda_available: if cuda_available:
from theano.sandbox.cuda import CudaNdarrayType from theano.sandbox.cuda import CudaNdarrayType
from theano.sandbox.cuda.basic_ops import host_from_gpu, gpu_from_host from theano.sandbox.cuda.basic_ops import host_from_gpu, gpu_from_host
from theano.sandbox.cuda.opt import register_opt
class MultinomialFromUniform(Op): class MultinomialFromUniform(Op):
'''Converts samples from a uniform into sample from a multinomial.''' '''Converts samples from a uniform into sample from a multinomial.'''
...@@ -283,11 +283,15 @@ class GpuMultinomialFromUniform(MultinomialFromUniform): ...@@ -283,11 +283,15 @@ class GpuMultinomialFromUniform(MultinomialFromUniform):
@local_optimizer() @local_optimizer()
def use_gpu_multinomial(node): def use_gpu_multinomial(node):
if isinstance(node.op, MultinomialFromUniform): if type(node.op) is MultinomialFromUniform:
p, u = node.inputs p, u = node.inputs
m, = node.outputs m, = node.outputs
if p.dtype == u.dtype == m.dtype == 'float32': if (p.dtype == u.dtype == m.dtype == 'float32' and
gpu_op = GpuMultinomialFromUniform(op.odtype) any([i.owner and isinstance(i.owner.op, theano.sandbox.cuda.HostFromGpu)
for i in node.inputs])):
gpu_op = GpuMultinomialFromUniform(node.op.odtype)
return [host_from_gpu(gpu_op(*[gpu_from_host(i) for i in node.inputs]))] return [host_from_gpu(gpu_op(*[gpu_from_host(i) for i in node.inputs]))]
if cuda_enabled:#theano.config.device.startswith('gpu'): if cuda_available:
register_specialize(use_gpu_multinomial) # Currently this it is bugged!
#register_opt()(use_gpu_multinomial)
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论