提交 3aedd9be authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

Precompile some variant of Reduction kernel during Theano compilation.

上级 1d2a81bb
...@@ -302,6 +302,20 @@ class GpuCAReduce(HideC, CAReduceDtype): ...@@ -302,6 +302,20 @@ class GpuCAReduce(HideC, CAReduceDtype):
return Apply(res.op, [input], [otype()]) return Apply(res.op, [input], [otype()])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
if self.axis is None:
redux = [True] * node.inputs[0].ndim
else:
redux = self.redux
acc_dtype = getattr(self, 'acc_dtype', None)
if acc_dtype is None:
acc_dtype = node.outputs[0].type.dtype
if any(redux):
node._cache_reduction_k = self.generate_kernel(node, acc_dtype,
redux)
return super(GpuCAReduce, self).make_thunk(node, storage_map,
compute_map, no_recycling)
def generate_kernel(self, node, odtype, redux): def generate_kernel(self, node, odtype, redux):
if isinstance(self.scalar_op, scalar.basic.Add): if isinstance(self.scalar_op, scalar.basic.Add):
reduce_expr = "a + b" reduce_expr = "a + b"
...@@ -311,7 +325,9 @@ class GpuCAReduce(HideC, CAReduceDtype): ...@@ -311,7 +325,9 @@ class GpuCAReduce(HideC, CAReduceDtype):
raise NotImplementedError() raise NotImplementedError()
return ReductionKernel(pygpu.get_default_context(), odtype, return ReductionKernel(pygpu.get_default_context(), odtype,
self.scalar_op.identity, reduce_expr, redux, self.scalar_op.identity, reduce_expr, redux,
arguments=[make_argument(node.inputs[0], 'a')]) arguments=[make_argument(node.inputs[0], 'a')],
init_nd=node.inputs[0].ndim
)
def perform(self, node, inp, out): def perform(self, node, inp, out):
input, = inp input, = inp
...@@ -322,14 +338,7 @@ class GpuCAReduce(HideC, CAReduceDtype): ...@@ -322,14 +338,7 @@ class GpuCAReduce(HideC, CAReduceDtype):
else: else:
redux = self.redux redux = self.redux
acc_dtype = getattr(self, 'acc_dtype', None)
if acc_dtype is None:
acc_dtype = node.outputs[0].type.dtype
if any(redux): if any(redux):
if not hasattr(node, '_cache_reduction_k'):
node._cache_reduction_k = self.generate_kernel(node, acc_dtype,
redux)
output[0] = node._cache_reduction_k(input).astype(copy=False, output[0] = node._cache_reduction_k(input).astype(copy=False,
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论