提交 38199e38 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix optimization for register_opt2

上级 e9e07abd
......@@ -487,16 +487,12 @@ def local_gpua_multinomial(op, context_name, inputs, outputs):
@register_opt()
@op_lifter([theano.sandbox.multinomial.MultinomialWOReplacementFromUniform])
def local_gpua_multinomial_wor(node, context_name):
@register_opt2([theano.sandbox.multinomial.MultinomialWOReplacementFromUniform], 'fast_compile')
def local_gpua_multinomial_wor(op, context_name, inputs, outputs):
# TODO : need description for function
p, u, n = node.inputs
# try:
# if get_scalar_constant_value(n_samples) != 1:
# return None
# except NotScalarConstantError:
# return None
m, = node.outputs
p, u, n = inputs
m, = outputs
if ((p.dtype == u.dtype == 'float32') and (m.dtype == 'int64')):
gpu_op = GPUAMultinomialWOReplacementFromUniform(node.op.odtype)
gpu_op = GPUAMultinomialWOReplacementFromUniform(op.odtype)
return gpuarray.elemwise.GpuDimShuffle([False, False], [1, 0])(
gpu_op(p, u, n))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论