提交 dec15a5f authored 作者: James Bergstra's avatar James Bergstra

cuda elemwise generator made uncompilable code when an input appeared twice.

上级 a7ca0c4f
...@@ -834,14 +834,22 @@ nd_collapse_[i]=0; ...@@ -834,14 +834,22 @@ nd_collapse_[i]=0;
""" %locals() """ %locals()
#check that all inputs have valid dimensions #check that all inputs have valid dimensions
emitted_inames = {}
for id,iname in enumerate(inputs): for id,iname in enumerate(inputs):
if iname in emitted_inames:
assert emitted_inames[iname] is node.inputs[id]
continue
broadcasts = ', '.join(map(str,map(int,node.inputs[id].broadcastable))) broadcasts = ', '.join(map(str,map(int,node.inputs[id].broadcastable)))
nd = node.inputs[id].ndim nd = node.inputs[id].ndim
print >> sio, """ print >> sio, """
int broadcasts_%(iname)s[%(nd)s] = {%(broadcasts)s}; int broadcasts_%(iname)s[%(nd)s] = {%(broadcasts)s};
""" %locals() """ %locals()
emitted_inames[iname] = node.inputs[id]
#check that all inputs have valid dimensions #check that all inputs have valid dimensions
emitted_inames = {}
for id,iname in enumerate(inputs): for id,iname in enumerate(inputs):
if iname in emitted_inames:
continue
print >> sio, """ print >> sio, """
//std::cerr << "C_CODE %(opname)s checking input %(iname)s\\n"; //std::cerr << "C_CODE %(opname)s checking input %(iname)s\\n";
if (%(nd)s != %(iname)s->nd) if (%(nd)s != %(iname)s->nd)
...@@ -864,6 +872,7 @@ nd_collapse_[i]=0; ...@@ -864,6 +872,7 @@ nd_collapse_[i]=0;
} }
} }
""" %locals() """ %locals()
emitted_inames[iname] = True
#check that all outputs have valid dimensions #check that all outputs have valid dimensions
for oname in outputs: for oname in outputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论