提交 64a29098 authored 作者: James Bergstra's avatar James Bergstra

added version to softmax

上级 61fb2c8c
......@@ -25,6 +25,8 @@ def nvcc_kernel(name, params, body):
def code_version(version):
"""decorator to support version-based cache mechanism"""
if not isinstance(version, tuple):
raise TypeError('version must be tuple', version)
def deco(f):
f.code_version = version
return f
......@@ -32,7 +34,7 @@ def code_version(version):
UNVERSIONED = ()
@code_version(UNVERSIONED)
@code_version((1,))
def inline_reduce(N, buf, pos, count, manner_fn):
"""
Return C++ code for a function that reduces a contiguous buffer.
......@@ -103,7 +105,7 @@ def inline_reduce_prod(N, buf, pos, count):
return inline_reduce(N, buf, pos, count, lambda a, b: "%s * %s"%(a,b))
@code_version(UNVERSIONED + inline_reduce_max.code_version + inline_reduce_sum.code_version)
@code_version((1,) + inline_reduce_max.code_version + inline_reduce_sum.code_version)
def inline_softmax(N, buf, buf2, threadPos, threadCount):
"""
:Precondition: buf and buf2 contain two identical copies of the input to softmax
......
......@@ -301,8 +301,8 @@ class GpuSoftmax (Op):
def make_node(self, x):
return Apply(self, [x],[x.type()])
def c_code_cache_version(self):
return ()
# reduce (1,) + device_softmax.code_version
#return ()
return (1,) + inline_softmax.code_version
def c_code(self, node, nodename, (x,), (z,), sub):
fail = sub['fail']
return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论