提交 283546f5 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed C code cache version for GpuCAReduce

上级 e0cca017
...@@ -8,6 +8,7 @@ import numpy ...@@ -8,6 +8,7 @@ import numpy
import theano import theano
from theano import Op, Type, Apply, Variable, Constant from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, scalar, config from theano import tensor, scalar, config
from theano.scalar import Scalar
scal = scalar # somewhere scalar gets reassigned to be a function scal = scalar # somewhere scalar gets reassigned to be a function
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
...@@ -529,7 +530,8 @@ class GpuCAReduce(GpuOp): ...@@ -529,7 +530,8 @@ class GpuCAReduce(GpuOp):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
self.reduce_mask == other.reduce_mask) self.reduce_mask == other.reduce_mask and
self.scalar_op == other.scalar_op)
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.reduce_mask) return hash(type(self)) ^ hash(self.reduce_mask)
...@@ -1528,14 +1530,19 @@ class GpuCAReduce(GpuOp): ...@@ -1528,14 +1530,19 @@ class GpuCAReduce(GpuOp):
} }
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version_apply(self, node):
op_version = self.scalar_op.cuda_assign_reduce_code_cache_version() version = [4] # the version corresponding to the c code in this Op
if op_version:
# our version is 0 # now we insert versions for the ops on which we depend...
rval = (0,) + op_version scalar_node = Apply(self.scalar_op,
return rval [Scalar(dtype=input.type.dtype)() for input in node.inputs],
[Scalar(dtype=output.type.dtype)() for output in node.outputs])
version.extend(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs:
version.extend(Scalar(dtype=i.type.dtype).c_code_cache_version())
if all(version):
return tuple(version)
else: else:
# we can't support caching if the op doesn't
return () return ()
def _op_guard(self): def _op_guard(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论