提交 e05cd06e authored 作者: James Bergstra's avatar James Bergstra

fix: Elemwise returns all support code

Elemwise should return both c_support_code and c_support_code_apply for the scalar_op. It now does this and there is a comment explaining why.
上级 8d8ab48e
......@@ -2100,6 +2100,16 @@ class Composite(ScalarOp):
return ()
return tuple(rval)
def c_support_code(self):
rval = []
for subnode in self.env.toposort():
try:
rval.append(subnode.op.c_support_code())
except gof.utils.MethodNotDefined:
pass
# remove duplicate code blocks
return "\n".join(sorted(set(rval)))
def c_support_code_apply(self, node, name):
rval = []
for subnode, subnodename in zip(self.env.toposort(), self.nodenames):
......@@ -2110,6 +2120,10 @@ class Composite(ScalarOp):
subnodename % dict(nodename=name)))
except gof.utils.MethodNotDefined:
pass
# there should be no need to remove duplicate code blocks because
# each block should have been specialized for the given nodename.
# Any block that isn't specialized should be returned via
# c_support_code instead of c_support_code_apply.
return "\n".join(rval)
def __eq__(self, other):
......
......@@ -929,6 +929,9 @@ class Elemwise(Op):
def c_headers(self):
return ['<vector>', '<algorithm>']
def c_support_code(self):
return self.scalar_op.c_support_code()
def c_support_code_apply(self, node, nodename):
support_code = self.scalar_op.c_support_code_apply(node,
nodename + '_scalar_')
......
......@@ -4979,15 +4979,23 @@ def test_mod():
def test_mod_compile():
"""
This test generate an Elemwise of Composite as:
Elemwise{Composite{Composite{Composite{Composite{mod,EQ},Switch},mul},add}}
The c_code generated is not compiling as of 30 June 2010. I fix the compilation in the same commit.
Elemwise{
Composite{
Composite{
Composite{
Composite{mod,EQ},
Switch},
mul},
add}}
The c_code generated is not compiling as of 30 June 2010. I fix the
compilation in the same commit.
"""
x = tensor.vector()
y = tensor.vector()
shape = x.shape
out = tensor.switch(tensor.eq(3%x.shape[0],0),y,y[:-1])
out = tensor.switch(tensor.eq(3 % x.shape[0], 0), y, y[:-1])
f = theano.function([x,y],out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论