提交 0b5759af authored 作者: Frederic Bastien's avatar Frederic Bastien

small code clean up to the new interface

上级 bc04c904
...@@ -9,11 +9,10 @@ from theano.tensor import NotScalarConstantError, get_scalar_constant_value ...@@ -9,11 +9,10 @@ from theano.tensor import NotScalarConstantError, get_scalar_constant_value
from theano.scalar import as_scalar from theano.scalar import as_scalar
import copy import copy
from theano.sandbox.cuda import cuda_available, GpuOp from theano.sandbox.cuda import cuda_available, GpuOp, register_opt
if cuda_available: if cuda_available:
from theano.sandbox.cuda import CudaNdarrayType from theano.sandbox.cuda import CudaNdarrayType
from theano.sandbox.cuda.basic_ops import host_from_gpu, gpu_from_host from theano.sandbox.cuda.basic_ops import host_from_gpu, gpu_from_host
from theano.sandbox.cuda.opt import register_opt
class MultinomialFromUniform(Op): class MultinomialFromUniform(Op):
...@@ -565,6 +564,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp): ...@@ -565,6 +564,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
""" % locals() """ % locals()
@register_opt()
@local_optimizer([MultinomialFromUniform]) @local_optimizer([MultinomialFromUniform])
def local_gpu_multinomial(node): def local_gpu_multinomial(node):
# TODO : need description for function # TODO : need description for function
...@@ -608,7 +608,3 @@ def local_gpu_multinomial(node): ...@@ -608,7 +608,3 @@ def local_gpu_multinomial(node):
# The dimshuffle is on the cpu, but will be moved to the # The dimshuffle is on the cpu, but will be moved to the
# gpu by an opt. # gpu by an opt.
return [gpu_from_host(ret)] return [gpu_from_host(ret)]
if cuda_available:
register_opt()(local_gpu_multinomial)
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论