提交 22b5da98 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix argument definitions for outputs.

上级 65af716c
...@@ -167,21 +167,19 @@ class GpuElemwise(HideC, Elemwise): ...@@ -167,21 +167,19 @@ class GpuElemwise(HideC, Elemwise):
""" % dict(n=n, name='"%s"' % (name,), """ % dict(n=n, name='"%s"' % (name,),
typecode=i.type.typecode) typecode=i.type.typecode)
p = 0 p = len(inps)
for n, o in enumerate(node.outputs): for n, o in enumerate(node.outputs):
if n in self.inplace_pattern: if n in self.inplace_pattern:
assert(len(node.outputs) == 1) assert(len(node.outputs) == 1)
res += "\nargs[%(n)s].flags |= GE_WRITE;\n" % dict(n=self.inplace_pattern[n]) res += "\nargs[%(n)s].flags |= GE_WRITE;\n" % dict(n=self.inplace_pattern[n])
else: else:
nn = len(inps) + p
name = outs[p]
p += 1
res += """ res += """
args[%(n)s].name = %(name)s; args[%(n)s].name = %(name)s;
args[%(n)s].typecode = %(typecode)s; args[%(n)s].typecode = %(typecode)s;
args[%(n)s].flags = GE_WRITE; args[%(n)s].flags = GE_WRITE;
""" % dict(n=nn, name='"%s"' % (name,), """ % dict(n=p, name='"%s"' % (outs[n],),
typecode=o.type.typecode) typecode=o.type.typecode)
p += 1
res += """ res += """
ge = GpuElemwise_new(%(ctx)s->ctx, %(support)s, %(kop)s, %(nargs)s, args, %(nd)s, GE_CONVERT_F16); ge = GpuElemwise_new(%(ctx)s->ctx, %(support)s, %(kop)s, %(nargs)s, args, %(nd)s, GE_CONVERT_F16);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论