提交 1c8a3067 authored 作者: James Bergstra's avatar James Bergstra

Added Op.c_code_cache_version_apply(node) function that permits proper

[recursive] handling of code versioning for ops like Elemwise, CAReduce, etc. Changed the default cache version of Op to (). Added c_code_cache_version() functions to many Ops that were using the default before.
上级 d87f1f91
...@@ -815,7 +815,7 @@ class CLinker(link.Linker): ...@@ -815,7 +815,7 @@ class CLinker(link.Linker):
return (op_pos[i.owner], i.owner.outputs.index(i)) return (op_pos[i.owner], i.owner.outputs.index(i))
for opos, o in enumerate(order): for opos, o in enumerate(order):
version.append(o.op.c_code_cache_version()) version.append(o.op.c_code_cache_version_apply(o))
for i in o.inputs: for i in o.inputs:
version.append(i.type.c_code_cache_version()) version.append(i.type.c_code_cache_version())
for i in o.outputs: for i in o.outputs:
......
...@@ -106,8 +106,27 @@ class CLinkerObject(object): ...@@ -106,8 +106,27 @@ class CLinkerObject(object):
The cache mechanism may erase cached modules that have been superceded by newer The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details. versions. See `ModuleCache` for details.
:note: See also `c_code_cache_version_apply()`
"""
return ()
def c_code_cache_version_apply(self, node):
"""Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
:note: See also `c_code_cache_version()`
:note: This function overrides `c_code_cache_version` unless it explicitly calls
`c_code_cache_version`. The default implementation simply calls `c_code_cache_version`
and ignores the `node` argument.
""" """
return (1,) return self.c_code_cache_version()
def c_compile_args(self): def c_compile_args(self):
"""Optional: Return a list of compile args recommended to compile the """Optional: Return a list of compile args recommended to compile the
......
...@@ -177,6 +177,16 @@ class CLinkerType(CLinkerObject): ...@@ -177,6 +177,16 @@ class CLinkerType(CLinkerObject):
""" """
raise MethodNotDefined("c_sync", type(self), self.__class__.__name__) raise MethodNotDefined("c_sync", type(self), self.__class__.__name__)
def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Type.
An empty tuple indicates an 'unversioned' Type that will not be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
"""
return ()
......
...@@ -460,9 +460,16 @@ class TensorType(Type): ...@@ -460,9 +460,16 @@ class TensorType(Type):
def c_libraries(self): def c_libraries(self):
return [] return []
def c_support_code(cls): def c_support_code(self):
"""Override `CLinkerOp.c_support_code` """ """Override `CLinkerOp.c_support_code` """
return scal.Scalar("int8").c_support_code() return scal.Scalar(self.dtype).c_support_code()
def c_code_cache_version(self):
scalar_version = scal.Scalar(self.dtype).c_code_cache_version()
if scalar_version:
return (1,) + scalar_version
else:
return ()
# Easy constructors # Easy constructors
......
...@@ -252,6 +252,8 @@ class GemmRelated(Op): ...@@ -252,6 +252,8 @@ class GemmRelated(Op):
self.case_double_gemm, self.case_double_gemm,
self.end_switch_typenum), '') self.end_switch_typenum), '')
def build_gemm_version(self):
return (1,)
class Gemm(GemmRelated): class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation): """In-place version of matrix-matrix multiplication (with accumulation):
...@@ -360,7 +362,14 @@ class Gemm(GemmRelated): ...@@ -360,7 +362,14 @@ class Gemm(GemmRelated):
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): #DEBUG def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): #DEBUG
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
def c_code_cache_version(self):
return (1,) + self.build_gemm_version()
gemm = Gemm() gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm')) pprint.assign(gemm, FunctionPrinter('gemm'))
def res_is_a(node, op, maxclients=None): def res_is_a(node, op, maxclients=None):
...@@ -632,6 +641,9 @@ class Dot22(GemmRelated): ...@@ -632,6 +641,9 @@ class Dot22(GemmRelated):
def c_code(self, node, name, (_x, _y), (_z, ), sub): #DEBUG def c_code(self, node, name, (_x, _y), (_z, ), sub): #DEBUG
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
def c_code_cache_version(self):
return (1,) + self.build_gemm_version()
_dot22 = Dot22() _dot22 = Dot22()
@local_optimizer([T.dot]) @local_optimizer([T.dot])
......
...@@ -295,6 +295,9 @@ class DimShuffle(Op): ...@@ -295,6 +295,9 @@ class DimShuffle(Op):
return full_code % dict(locals(), **sub) return full_code % dict(locals(), **sub)
def c_code_cache_version(self):
return (1,)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
gz = as_tensor_variable(gz) gz = as_tensor_variable(gz)
grad_order = ['x'] * len(x.type.broadcastable) grad_order = ['x'] * len(x.type.broadcastable)
...@@ -696,8 +699,20 @@ class Elemwise(Op): ...@@ -696,8 +699,20 @@ class Elemwise(Op):
def c_support_code(self): def c_support_code(self):
return self.scalar_op.c_support_code() return self.scalar_op.c_support_code()
def c_code_cache_version(self): def c_code_cache_version_apply(self, node):
return (4, 1) version = [4] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op,
[Scalar(dtype = input.type.dtype)() for input in node.inputs],
[Scalar(dtype = output.type.dtype)() for output in node.outputs])
version.extend(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs:
version.extend(Scalar(dtype=i.type.dtype).c_code_cache_version())
if all(version):
return tuple(version)
else:
return ()
# def elemwise_to_scal(env): # def elemwise_to_scal(env):
# mapping = {} # mapping = {}
...@@ -885,8 +900,20 @@ class CAReduce(Op): ...@@ -885,8 +900,20 @@ class CAReduce(Op):
code = "\n".join(self._c_all(node, name, inames, onames, sub)) code = "\n".join(self._c_all(node, name, inames, onames, sub))
return code return code
def c_code_cache_version(self): def c_code_cache_version_apply(self, node):
return (1, 0) version = [2] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op,
[Scalar(dtype = input.type.dtype)() for input in node.inputs],
[Scalar(dtype = output.type.dtype)() for output in node.outputs])
version.extend(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs:
version.extend(Scalar(dtype=i.type.dtype).c_code_cache_version())
if all(version):
return tuple(version)
else:
return ()
class Sum(CAReduce): class Sum(CAReduce):
......
...@@ -40,6 +40,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp): ...@@ -40,6 +40,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
? 1.0 ? 1.0
: 1.0 /(1.0+exp(-%(x)s));""" % locals() : 1.0 /(1.0+exp(-%(x)s));""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
def c_code_cache_version(self):
return (1,)
scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid') scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid')
sigmoid = elemwise.Elemwise(scalar_sigmoid, name='sigmoid') sigmoid = elemwise.Elemwise(scalar_sigmoid, name='sigmoid')
...@@ -67,6 +69,8 @@ class ScalarSoftplus(scalar.UnaryScalarOp): ...@@ -67,6 +69,8 @@ class ScalarSoftplus(scalar.UnaryScalarOp):
? %(x)s ? %(x)s
: log1p(exp(%(x)s));""" % locals() : log1p(exp(%(x)s));""" % locals()
raise NotImplementedError('only floating point x is implemented') raise NotImplementedError('only floating point x is implemented')
def c_code_cache_version(self):
return (1,)
scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus') scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus')
softplus = elemwise.Elemwise(scalar_softplus, name='softplus') softplus = elemwise.Elemwise(scalar_softplus, name='softplus')
...@@ -134,7 +138,7 @@ class SoftmaxWithBias(gof.Op): ...@@ -134,7 +138,7 @@ class SoftmaxWithBias(gof.Op):
return ['<iostream>','<cmath>'] return ['<iostream>','<cmath>']
def c_code_cache_version(self): def c_code_cache_version(self):
return () return (3,)
@staticmethod @staticmethod
def c_code_template(): def c_code_template():
# this implementation was lifted from # this implementation was lifted from
...@@ -295,7 +299,7 @@ class SoftmaxGrad(gof.Op): ...@@ -295,7 +299,7 @@ class SoftmaxGrad(gof.Op):
raise NotImplementedError() raise NotImplementedError()
def c_code_cache_version(self): def c_code_cache_version(self):
return () return (3,)
def c_code(self, node, name, (dy, sm), (dx,), sub): def c_code(self, node, name, (dy, sm), (dx,), sub):
return ''' return '''
if ((%(dy)s->descr->type_num != PyArray_DOUBLE) && (%(dy)s->descr->type_num != PyArray_FLOAT)) if ((%(dy)s->descr->type_num != PyArray_DOUBLE) && (%(dy)s->descr->type_num != PyArray_FLOAT))
...@@ -633,7 +637,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -633,7 +637,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
def c_code_cache_version(self): def c_code_cache_version(self):
return () return (2,)
def c_code(self, node, name, (x, b, y_idx), (nll, sm, am), sub): def c_code(self, node, name, (x, b, y_idx), (nll, sm, am), sub):
y_idx_type = node.inputs[2].type.dtype_specs()[1] y_idx_type = node.inputs[2].type.dtype_specs()[1]
am_type = y_idx_type am_type = y_idx_type
...@@ -665,7 +669,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -665,7 +669,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
def grad(self, *args): def grad(self, *args):
raise NotImplementedError() raise NotImplementedError()
def c_code_cache_version(self): def c_code_cache_version(self):
return () return (2,)
def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub): def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub):
y_idx_type = node.inputs[2].type.dtype_specs()[1] y_idx_type = node.inputs[2].type.dtype_specs()[1]
return """ return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论