提交 64a29098 authored 作者: James Bergstra's avatar James Bergstra

added version to softmax

上级 61fb2c8c
...@@ -25,6 +25,8 @@ def nvcc_kernel(name, params, body): ...@@ -25,6 +25,8 @@ def nvcc_kernel(name, params, body):
def code_version(version): def code_version(version):
"""decorator to support version-based cache mechanism""" """decorator to support version-based cache mechanism"""
if not isinstance(version, tuple):
raise TypeError('version must be tuple', version)
def deco(f): def deco(f):
f.code_version = version f.code_version = version
return f return f
...@@ -32,7 +34,7 @@ def code_version(version): ...@@ -32,7 +34,7 @@ def code_version(version):
UNVERSIONED = () UNVERSIONED = ()
@code_version(UNVERSIONED) @code_version((1,))
def inline_reduce(N, buf, pos, count, manner_fn): def inline_reduce(N, buf, pos, count, manner_fn):
""" """
Return C++ code for a function that reduces a contiguous buffer. Return C++ code for a function that reduces a contiguous buffer.
...@@ -103,7 +105,7 @@ def inline_reduce_prod(N, buf, pos, count): ...@@ -103,7 +105,7 @@ def inline_reduce_prod(N, buf, pos, count):
return inline_reduce(N, buf, pos, count, lambda a, b: "%s * %s"%(a,b)) return inline_reduce(N, buf, pos, count, lambda a, b: "%s * %s"%(a,b))
@code_version(UNVERSIONED + inline_reduce_max.code_version + inline_reduce_sum.code_version) @code_version((1,) + inline_reduce_max.code_version + inline_reduce_sum.code_version)
def inline_softmax(N, buf, buf2, threadPos, threadCount): def inline_softmax(N, buf, buf2, threadPos, threadCount):
""" """
:Precondition: buf and buf2 contain two identical copies of the input to softmax :Precondition: buf and buf2 contain two identical copies of the input to softmax
......
...@@ -301,8 +301,8 @@ class GpuSoftmax (Op): ...@@ -301,8 +301,8 @@ class GpuSoftmax (Op):
def make_node(self, x): def make_node(self, x):
return Apply(self, [x],[x.type()]) return Apply(self, [x],[x.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
return () #return ()
# reduce (1,) + device_softmax.code_version return (1,) + inline_softmax.code_version
def c_code(self, node, nodename, (x,), (z,), sub): def c_code(self, node, nodename, (x,), (z,), sub):
fail = sub['fail'] fail = sub['fail']
return """ return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论