提交 e05cd06e authored 作者: James Bergstra's avatar James Bergstra

fix: Elemwise returns all support code

Elemwise should return both c_support_code and c_support_code_apply for the scalar_op. It now does this and there is a comment explaining why.
上级 8d8ab48e
...@@ -2100,6 +2100,16 @@ class Composite(ScalarOp): ...@@ -2100,6 +2100,16 @@ class Composite(ScalarOp):
return () return ()
return tuple(rval) return tuple(rval)
def c_support_code(self):
rval = []
for subnode in self.env.toposort():
try:
rval.append(subnode.op.c_support_code())
except gof.utils.MethodNotDefined:
pass
# remove duplicate code blocks
return "\n".join(sorted(set(rval)))
def c_support_code_apply(self, node, name): def c_support_code_apply(self, node, name):
rval = [] rval = []
for subnode, subnodename in zip(self.env.toposort(), self.nodenames): for subnode, subnodename in zip(self.env.toposort(), self.nodenames):
...@@ -2110,6 +2120,10 @@ class Composite(ScalarOp): ...@@ -2110,6 +2120,10 @@ class Composite(ScalarOp):
subnodename % dict(nodename=name))) subnodename % dict(nodename=name)))
except gof.utils.MethodNotDefined: except gof.utils.MethodNotDefined:
pass pass
# there should be no need to remove duplicate code blocks because
# each block should have been specialized for the given nodename.
# Any block that isn't specialized should be returned via
# c_support_code instead of c_support_code_apply.
return "\n".join(rval) return "\n".join(rval)
def __eq__(self, other): def __eq__(self, other):
......
...@@ -929,6 +929,9 @@ class Elemwise(Op): ...@@ -929,6 +929,9 @@ class Elemwise(Op):
def c_headers(self): def c_headers(self):
return ['<vector>', '<algorithm>'] return ['<vector>', '<algorithm>']
def c_support_code(self):
return self.scalar_op.c_support_code()
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
support_code = self.scalar_op.c_support_code_apply(node, support_code = self.scalar_op.c_support_code_apply(node,
nodename + '_scalar_') nodename + '_scalar_')
......
...@@ -4979,15 +4979,23 @@ def test_mod(): ...@@ -4979,15 +4979,23 @@ def test_mod():
def test_mod_compile(): def test_mod_compile():
""" """
This test generate an Elemwise of Composite as: This test generate an Elemwise of Composite as:
Elemwise{Composite{Composite{Composite{Composite{mod,EQ},Switch},mul},add}} Elemwise{
Composite{
The c_code generated is not compiling as of 30 June 2010. I fix the compilation in the same commit. Composite{
Composite{
Composite{mod,EQ},
Switch},
mul},
add}}
The c_code generated is not compiling as of 30 June 2010. I fix the
compilation in the same commit.
""" """
x = tensor.vector() x = tensor.vector()
y = tensor.vector() y = tensor.vector()
shape = x.shape shape = x.shape
out = tensor.switch(tensor.eq(3%x.shape[0],0),y,y[:-1]) out = tensor.switch(tensor.eq(3 % x.shape[0], 0), y, y[:-1])
f = theano.function([x,y],out) f = theano.function([x,y],out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论