提交 283546f5 作者: Ian Goodfellow

fixed C code cache version for GpuCAReduce

上级 e0cca017
......@@ -8,6 +8,7 @@ import numpy
import theano
from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, scalar, config
from theano.scalar import Scalar
scal = scalar # somewhere scalar gets reassigned to be a function
from theano.gof.python25 import all, any
......@@ -529,7 +530,8 @@ class GpuCAReduce(GpuOp):
def __eq__(self, other):
    """Return True when ``other`` is an equivalent reduction op.

    Two ops compare equal only if they have exactly the same type, the
    same ``reduce_mask`` and the same ``scalar_op``.
    """
    # The scalar_op comparison matters: two reductions over the same
    # axes but with different scalar operations must not be treated as
    # interchangeable.
    return (type(self) == type(other) and
            self.reduce_mask == other.reduce_mask and
            self.scalar_op == other.scalar_op)
def __hash__(self):
    """Return a hash built from the op's type and its ``reduce_mask``."""
    # NOTE(review): __eq__ also compares scalar_op, which is not mixed in
    # here. Equal ops still hash equally, so the hash contract holds; ops
    # differing only in scalar_op simply share a bucket. Confirm this is
    # intended before changing it.
    return hash(type(self)) ^ hash(self.reduce_mask)
......@@ -1528,14 +1530,19 @@ class GpuCAReduce(GpuOp):
}
""" % locals()
def c_code_cache_version(self):
    """Return the C-code cache version of this op.

    The version is this op's own version number (0) followed by the
    version reported by
    ``self.scalar_op.cuda_assign_reduce_code_cache_version()``.

    Returns an empty tuple when the scalar op reports no version, so the
    result is always a tuple rather than an implicit ``None``.
    """
    op_version = self.scalar_op.cuda_assign_reduce_code_cache_version()
    if not op_version:
        # we can't support caching if the op doesn't
        return ()
    # our version is 0
    return (0,) + tuple(op_version)
def c_code_cache_version_apply(self, node):
    """Return the C-code cache version for this op applied at ``node``.

    The result starts with this op's own version number and is followed
    by the version of ``self.scalar_op`` (queried on a scalar Apply node
    built from the dtypes of ``node``'s inputs and outputs) and by the
    version of the ``Scalar`` type of every input and output.

    Returns an empty tuple when any of those dependencies reports an
    empty version, or when any collected component is falsy.
    """
    version = [4]  # the version corresponding to the c code in this Op

    # now we insert versions for the ops on which we depend...
    scalar_node = Apply(
        self.scalar_op,
        [Scalar(dtype=inp.type.dtype)() for inp in node.inputs],
        [Scalar(dtype=out.type.dtype)() for out in node.outputs])
    dependency_versions = [
        self.scalar_op.c_code_cache_version_apply(scalar_node)]
    for var in node.inputs + node.outputs:
        dependency_versions.append(
            Scalar(dtype=var.type.dtype).c_code_cache_version())

    for dep_version in dependency_versions:
        if not dep_version:
            # An empty version would vanish when extended into the flat
            # list, so the check below could never notice it: test for
            # it explicitly.
            # we can't support caching if the op doesn't
            return ()
        version.extend(dep_version)

    # Kept from the original: any falsy component disables caching.
    if all(version):
        return tuple(version)
    else:
        return ()
def _op_guard(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论