提交 38199e38 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix optimization for register_opt2

上级 e9e07abd
...@@ -487,16 +487,12 @@ def local_gpua_multinomial(op, context_name, inputs, outputs): ...@@ -487,16 +487,12 @@ def local_gpua_multinomial(op, context_name, inputs, outputs):
@register_opt() @register_opt()
@op_lifter([theano.sandbox.multinomial.MultinomialWOReplacementFromUniform]) @op_lifter([theano.sandbox.multinomial.MultinomialWOReplacementFromUniform])
def local_gpua_multinomial_wor(node, context_name): @register_opt2([theano.sandbox.multinomial.MultinomialWOReplacementFromUniform], 'fast_compile')
def local_gpua_multinomial_wor(op, context_name, inputs, outputs):
# TODO : need description for function # TODO : need description for function
p, u, n = node.inputs p, u, n = inputs
# try: m, = outputs
# if get_scalar_constant_value(n_samples) != 1:
# return None
# except NotScalarConstantError:
# return None
m, = node.outputs
if ((p.dtype == u.dtype == 'float32') and (m.dtype == 'int64')): if ((p.dtype == u.dtype == 'float32') and (m.dtype == 'int64')):
gpu_op = GPUAMultinomialWOReplacementFromUniform(node.op.odtype) gpu_op = GPUAMultinomialWOReplacementFromUniform(op.odtype)
return gpuarray.elemwise.GpuDimShuffle([False, False], [1, 0])( return gpuarray.elemwise.GpuDimShuffle([False, False], [1, 0])(
gpu_op(p, u, n)) gpu_op(p, u, n))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论