提交 4cdd1632 authored 作者: James Bergstra's avatar James Bergstra

corrected error in elemwise re: passing node name to scalar_op.c_code

上级 04f2a477
......@@ -793,7 +793,7 @@ class Elemwise(Op):
rval.append(tuple(oshp))
return rval
def _c_all(self, node, name, inames, onames, sub):
def _c_all(self, node, nodename, inames, onames, sub):
_inames = inames
_onames = onames
......@@ -901,7 +901,7 @@ class Elemwise(Op):
Apply(self.scalar_op,
[Scalar(dtype = input.type.dtype)() for input in node.inputs],
[Scalar(dtype = output.type.dtype)() for output in node.outputs]),
name + '_scalar_',
nodename + '_scalar_',
["%s_i" % s for s in _inames],
["%s_i" % s for s in onames],
sub)
......@@ -922,19 +922,20 @@ class Elemwise(Op):
sub = sub)
return decl, checks, alloc, loop
def c_code(self, node, name, inames, onames, sub):
code = "\n".join(self._c_all(node, name, inames, onames, sub))
def c_code(self, node, nodename, inames, onames, sub):
code = "\n".join(self._c_all(node, nodename, inames, onames, sub))
return code
def c_headers(self):
return ['<vector>', '<algorithm>']
def c_support_code(self):
support_code = self.scalar_op.c_support_code()
def c_support_code_apply(self, node, nodename):
support_code = self.scalar_op.c_support_code_apply(node,
nodename + '_scalar_')
return support_code
def c_code_cache_version_apply(self, node):
version = [5] # the version corresponding to the c code in this Op
version = [6] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论