提交 b3cf9269 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #177 from nouiz/important_fix

Important fix Everything looks good
......@@ -8,7 +8,7 @@ The elemwise fct are also used with scalar operation! So it can happen that ndim
import StringIO, sys
import numpy
from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, scalar
from theano import tensor, scalar, gof
import logging, copy
_logger_name = 'theano.sandbox.cuda.elemwise'
......@@ -42,8 +42,12 @@ class NaiveAlgo(object):
:param scalar_op: the scalar operation to execute on each element.
:param sync: if True, will wait after the kernel launch and check for error call.
"""
if scalar_op.c_support_code_apply(node=None, nodename="nodename"):
raise SupportCodeError(scalar_op)
try:
code = scalar_op.c_support_code_apply(node=None, name="nodename")
if code:
raise SupportCodeError(scalar_op)
except gof.utils.MethodNotDefined:
pass
self.scalar_op = scalar_op
self.sync = sync
self.inplace_pattern = inplace_pattern
......
......@@ -2097,14 +2097,14 @@ class Composite(ScalarOp):
return ()
return tuple(rval)
def c_support_code_apply(self, node, nodename):
def c_support_code_apply(self, node, name):
rval = []
for subnode, subnodename in zip(self.env.toposort(), self.nodenames):
try:
rval.append(
subnode.op.c_support_code_apply(
subnode,
subnodename % dict(nodename=nodename)))
subnodename % dict(nodename=name)))
except gof.utils.MethodNotDefined:
pass
return "\n".join(rval)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论