提交 580e1614 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added supports_c_code method to GpuCAReduce

上级 0c932d58
...@@ -8,6 +8,7 @@ import numpy ...@@ -8,6 +8,7 @@ import numpy
import theano import theano
from theano import Op, Type, Apply, Variable, Constant from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, scalar, config from theano import tensor, scalar, config
scal = scalar # somewhere scalar gets reassigned to be a function
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
...@@ -506,6 +507,7 @@ class GpuCAReduce(GpuOp): ...@@ -506,6 +507,7 @@ class GpuCAReduce(GpuOp):
be removed during graph optimization be removed during graph optimization
""" """
def __init__(self, reduce_mask, scalar_op): def __init__(self, reduce_mask, scalar_op):
self.reduce_mask = tuple(reduce_mask) self.reduce_mask = tuple(reduce_mask)
self.scalar_op = scalar_op self.scalar_op = scalar_op
...@@ -534,8 +536,47 @@ class GpuCAReduce(GpuOp): ...@@ -534,8 +536,47 @@ class GpuCAReduce(GpuOp):
x, = inp x, = inp
z, = out z, = out
self._op_guard() self._op_guard()
# reduce_max is declared but does nothing but
# raise NotImplementedError.
# We can't call it here anyway because it hasn't
# been added to the python bindings yet
z[0] = x.reduce_sum(self.reduce_mask) z[0] = x.reduce_sum(self.reduce_mask)
def supports_c_code(self, inputs):
""" Returns True if the current op and reduce pattern
has functioning C code """
# If we don't even have the right method, we certainly
# don't support the C code
# (This is the test that used to be implemented by
# local_gpu_sum)
pattern = (''.join(str(i) for i in self.reduce_mask))
if not hasattr(self, 'c_code_reduce_%s' % pattern):
return False
# Now that this is a general reduction op, we might
# have a method for a pattern, but that pattern
# might not be implemented for the current scalar op.
# To detect this more complicated situation, we
# make fake arguments to c_code, try to run them,
# and see if NotImplementedError gets raised.
node = self.make_node(*inputs)
name = 'fake_name'
inp = ['fake_input_name_%d' % i for i in xrange(len(inputs))]
out = ['fake_output_name_%d' % i for i in xrange(len(node.outputs))]
sub = { 'fail' : 'fake failure code' }
try:
self.c_code(node, name, inp, out, sub)
self.c_support_code_apply(node, name)
except NotImplementedError:
return False
return True
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
self._op_guard() self._op_guard()
x, = inp x, = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论