提交 5649f8b6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Check that the op is one of those that we support before replacing reduce Ops.

上级 894692ea
import copy import copy
import theano, numpy import theano, numpy
from theano import tensor from theano import tensor, scalar
from theano.compile import optdb from theano.compile import optdb
from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB, from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB,
Optimizer, toolbox, DestroyHandler, Optimizer, toolbox, DestroyHandler,
...@@ -176,6 +176,8 @@ def local_gpua_subtensor(node): ...@@ -176,6 +176,8 @@ def local_gpua_subtensor(node):
@register_opt() @register_opt()
@op_lifter(tensor.CAReduce) @op_lifter(tensor.CAReduce)
def local_gpua_careduce(node): def local_gpua_careduce(node):
if (isinstance(node.op.scalar_op, scalar.basic.Add) or
isinstance(node.op.scalar_op, scalar.basic.Mul)):
return GpuCAReduce(node.op.scalar_op, axis=node.op.axis, return GpuCAReduce(node.op.scalar_op, axis=node.op.axis,
dtype=getattr(node.op, 'dtype', None), dtype=getattr(node.op, 'dtype', None),
acc_dtype=getattr(node.op, 'acc_dtype', None)) acc_dtype=getattr(node.op, 'acc_dtype', None))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论