提交 5ca75d68 authored 作者: Frederic's avatar Frederic

fix the Composite change as it was making the buildbot crash.

We should always use the convention name for input to function that we override! There is more fix that are missing.
上级 1e8b0b00
...@@ -8,7 +8,7 @@ The elemwise fct are also used with scalar operation! So it can happen that ndim ...@@ -8,7 +8,7 @@ The elemwise fct are also used with scalar operation! So it can happen that ndim
import StringIO, sys import StringIO, sys
import numpy import numpy
from theano import Op, Type, Apply, Variable, Constant from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, scalar from theano import tensor, scalar, gof
import logging, copy import logging, copy
_logger_name = 'theano.sandbox.cuda.elemwise' _logger_name = 'theano.sandbox.cuda.elemwise'
...@@ -42,8 +42,12 @@ class NaiveAlgo(object): ...@@ -42,8 +42,12 @@ class NaiveAlgo(object):
:param scalar_op: the scalar operation to execute on each element. :param scalar_op: the scalar operation to execute on each element.
:param sync: if True, will wait after the kernel launch and check for error call. :param sync: if True, will wait after the kernel launch and check for error call.
""" """
if scalar_op.c_support_code_apply(node=None, nodename="nodename"): try:
code = scalar_op.c_support_code_apply(node=None, name="nodename")
if code:
raise SupportCodeError(scalar_op) raise SupportCodeError(scalar_op)
except gof.utils.MethodNotDefined:
pass
self.scalar_op = scalar_op self.scalar_op = scalar_op
self.sync = sync self.sync = sync
self.inplace_pattern = inplace_pattern self.inplace_pattern = inplace_pattern
......
...@@ -2097,14 +2097,14 @@ class Composite(ScalarOp): ...@@ -2097,14 +2097,14 @@ class Composite(ScalarOp):
return () return ()
return tuple(rval) return tuple(rval)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, name):
rval = [] rval = []
for subnode, subnodename in zip(self.env.toposort(), self.nodenames): for subnode, subnodename in zip(self.env.toposort(), self.nodenames):
try: try:
rval.append( rval.append(
subnode.op.c_support_code_apply( subnode.op.c_support_code_apply(
subnode, subnode,
subnodename % dict(nodename=nodename))) subnodename % dict(nodename=name)))
except gof.utils.MethodNotDefined: except gof.utils.MethodNotDefined:
pass pass
return "\n".join(rval) return "\n".join(rval)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论