Commit 1c8a3067 authored by James Bergstra

Added Op.c_code_cache_version_apply(node) function that permits proper

[recursive] handling of code versioning for ops like Elemwise, CAReduce, etc. Changed the default cache version of Op to (). Added c_code_cache_version() functions to many Ops that were using the default before.
Parent d87f1f91
......@@ -815,7 +815,7 @@ class CLinker(link.Linker):
return (op_pos[i.owner], i.owner.outputs.index(i))
for opos, o in enumerate(order):
version.append(o.op.c_code_cache_version())
version.append(o.op.c_code_cache_version_apply(o))
for i in o.inputs:
version.append(i.type.c_code_cache_version())
for i in o.outputs:
......
......@@ -106,8 +106,27 @@ class CLinkerObject(object):
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
:note: See also `c_code_cache_version_apply()`
"""
return ()
def c_code_cache_version_apply(self, node):
    """Return a tuple of integers indicating the version of this Op
    for a specific Apply `node`.

    An empty tuple indicates an 'unversioned' Op that will not be
    cached between processes.

    The cache mechanism may erase cached modules that have been
    superseded by newer versions.  See `ModuleCache` for details.

    :note: See also `c_code_cache_version()`.

    :note: This function overrides `c_code_cache_version` unless it
        explicitly calls `c_code_cache_version`.  The default
        implementation simply calls `c_code_cache_version` and ignores
        the `node` argument.
    """
    # The fused diff left a stale ``return (1,)`` ahead of this line,
    # which made the documented default unreachable; only the
    # delegation belongs here.
    return self.c_code_cache_version()
def c_compile_args(self):
"""Optional: Return a list of compile args recommended to compile the
......
......@@ -177,6 +177,16 @@ class CLinkerType(CLinkerObject):
"""
raise MethodNotDefined("c_sync", type(self), self.__class__.__name__)
def c_code_cache_version(self):
    """Return a tuple of integers indicating the version of this Type.

    An empty tuple marks the Type as 'unversioned': its generated code
    will not be cached between processes.  The cache mechanism may
    erase cached modules that have been superseded by newer versions;
    see `ModuleCache` for details.
    """
    # Default implementation: unversioned.
    return ()
......
......@@ -460,9 +460,16 @@ class TensorType(Type):
def c_libraries(self):
return []
def c_support_code(self):
    """Override `CLinkerOp.c_support_code`.

    Delegates to the scalar support code for this TensorType's own
    dtype.  (The fused diff also retained the superseded variant that
    took ``cls`` and always used the ``"int8"`` scalar — that duplicate
    header and hard-coded dtype are removed here.)
    """
    return scal.Scalar(self.dtype).c_support_code()
def c_code_cache_version(self):
    """Version tuple for this TensorType's generated C code.

    Prefixes the scalar version for this dtype with this class's own
    version number; an unversioned scalar makes this Type unversioned
    as well (empty tuple disables caching).
    """
    inner = scal.Scalar(self.dtype).c_code_cache_version()
    return (1,) + inner if inner else ()
# Easy constructors
......
......@@ -252,6 +252,8 @@ class GemmRelated(Op):
self.case_double_gemm,
self.end_switch_typenum), '')
def build_gemm_version(self):
    """Version of the shared GEMM C-code template built by this class."""
    return (1,)
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
......@@ -360,7 +362,14 @@ class Gemm(GemmRelated):
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): #DEBUG
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
def c_code_cache_version(self):
    """Cache version: this Op's own version plus the GEMM template's."""
    own_version = (1,)
    return own_version + self.build_gemm_version()
gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm'))
def res_is_a(node, op, maxclients=None):
......@@ -632,6 +641,9 @@ class Dot22(GemmRelated):
def c_code(self, node, name, (_x, _y), (_z, ), sub): #DEBUG
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
def c_code_cache_version(self):
    """Cache version: Dot22 shares the GEMM template, so its version is
    appended to this Op's own version prefix."""
    template_version = self.build_gemm_version()
    return (1,) + template_version
_dot22 = Dot22()
@local_optimizer([T.dot])
......
......@@ -295,6 +295,9 @@ class DimShuffle(Op):
return full_code % dict(locals(), **sub)
def c_code_cache_version(self):
    """Version of DimShuffle's generated C code."""
    return (1,)
def grad(self, (x, ), (gz, )):
gz = as_tensor_variable(gz)
grad_order = ['x'] * len(x.type.broadcastable)
......@@ -696,8 +699,20 @@ class Elemwise(Op):
def c_support_code(self):
return self.scalar_op.c_support_code()
def c_code_cache_version_apply(self, node):
    """Compute the cache version of this Elemwise at a specific `node`.

    Combines the version of Elemwise's own C code with the version of
    the wrapped scalar op (queried on an equivalent scalar Apply node)
    and the versions of the Scalar types of every input and output.
    If any component is falsy the whole Op is unversioned and an empty
    tuple is returned.

    (The fused diff also retained the superseded fixed-version
    ``c_code_cache_version`` returning ``(4, 1)``; it is removed here
    because it cannot account for the scalar op's own version.)
    """
    version = [4]  # version of the C code generated by Elemwise itself
    # Mirror this node with scalar variables so the scalar op can
    # report a node-specific version.
    scalar_node = Apply(
        self.scalar_op,
        [Scalar(dtype=inp.type.dtype)() for inp in node.inputs],
        [Scalar(dtype=out.type.dtype)() for out in node.outputs])
    version.extend(self.scalar_op.c_code_cache_version_apply(scalar_node))
    for var in node.inputs + node.outputs:
        version.extend(Scalar(dtype=var.type.dtype).c_code_cache_version())
    # Any falsy component (e.g. an unversioned dependency) disables caching.
    if all(version):
        return tuple(version)
    return ()
# def elemwise_to_scal(env):
# mapping = {}
......@@ -885,8 +900,20 @@ class CAReduce(Op):
code = "\n".join(self._c_all(node, name, inames, onames, sub))
return code
def c_code_cache_version_apply(self, node):
    """Compute the cache version of this CAReduce at a specific `node`.

    Combines the version of CAReduce's own C code with the version of
    the wrapped scalar op (queried on an equivalent scalar Apply node)
    and the versions of the Scalar types of every input and output.
    If any component is falsy the whole Op is unversioned and an empty
    tuple is returned.

    (The fused diff also retained the superseded fixed-version
    ``c_code_cache_version`` returning ``(1, 0)``; it is removed here
    because it cannot account for the scalar op's own version.)
    """
    version = [2]  # version of the C code generated by CAReduce itself
    # Mirror this node with scalar variables so the scalar op can
    # report a node-specific version.
    scalar_node = Apply(
        self.scalar_op,
        [Scalar(dtype=inp.type.dtype)() for inp in node.inputs],
        [Scalar(dtype=out.type.dtype)() for out in node.outputs])
    version.extend(self.scalar_op.c_code_cache_version_apply(scalar_node))
    for var in node.inputs + node.outputs:
        version.extend(Scalar(dtype=var.type.dtype).c_code_cache_version())
    # Any falsy component (e.g. an unversioned dependency) disables caching.
    if all(version):
        return tuple(version)
    return ()
class Sum(CAReduce):
......
......@@ -40,6 +40,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
? 1.0
: 1.0 /(1.0+exp(-%(x)s));""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def c_code_cache_version(self):
    """Version of ScalarSigmoid's generated C code."""
    return (1,)
scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid')
sigmoid = elemwise.Elemwise(scalar_sigmoid, name='sigmoid')
......@@ -67,6 +69,8 @@ class ScalarSoftplus(scalar.UnaryScalarOp):
? %(x)s
: log1p(exp(%(x)s));""" % locals()
raise NotImplementedError('only floating point x is implemented')
def c_code_cache_version(self):
    """Version of ScalarSoftplus's generated C code."""
    return (1,)
scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus')
softplus = elemwise.Elemwise(scalar_softplus, name='softplus')
......@@ -134,7 +138,7 @@ class SoftmaxWithBias(gof.Op):
return ['<iostream>','<cmath>']
def c_code_cache_version(self):
    """Version of SoftmaxWithBias's generated C code.

    The fused diff's leading ``return ()`` was the superseded
    (unversioned) line and made ``(3,)`` unreachable; only the bumped
    version belongs here.
    """
    return (3,)
@staticmethod
def c_code_template():
# this implementation was lifted from
......@@ -295,7 +299,7 @@ class SoftmaxGrad(gof.Op):
raise NotImplementedError()
def c_code_cache_version(self):
    """Version of SoftmaxGrad's generated C code.

    The fused diff's leading ``return ()`` was the superseded
    (unversioned) line and made ``(3,)`` unreachable; only the bumped
    version belongs here.
    """
    return (3,)
def c_code(self, node, name, (dy, sm), (dx,), sub):
return '''
if ((%(dy)s->descr->type_num != PyArray_DOUBLE) && (%(dy)s->descr->type_num != PyArray_FLOAT))
......@@ -633,7 +637,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
def c_code_cache_version(self):
    """Version of CrossentropySoftmaxArgmax1HotWithBias's C code.

    The fused diff's leading ``return ()`` was the superseded
    (unversioned) line and made ``(2,)`` unreachable; only the bumped
    version belongs here.
    """
    return (2,)
def c_code(self, node, name, (x, b, y_idx), (nll, sm, am), sub):
y_idx_type = node.inputs[2].type.dtype_specs()[1]
am_type = y_idx_type
......@@ -665,7 +669,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
def grad(self, *args):
raise NotImplementedError()
def c_code_cache_version(self):
    """Version of CrossentropySoftmax1HotWithBiasDx's C code.

    The fused diff's leading ``return ()`` was the superseded
    (unversioned) line and made ``(2,)`` unreachable; only the bumped
    version belongs here.
    """
    return (2,)
def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub):
y_idx_type = node.inputs[2].type.dtype_specs()[1]
return """
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment