提交 65079080 作者: nouiz

Merge pull request #436 from delallea/composite_support_fix

Fixed crash with Composite support code and GPU
...@@ -410,6 +410,34 @@ def test_elemwise_composite_float64(): ...@@ -410,6 +410,34 @@ def test_elemwise_composite_float64():
assert not any([i.type.dtype=='float64' for i in s.inputs+s.outputs]) assert not any([i.type.dtype=='float64' for i in s.inputs+s.outputs])
def test_elemwise_composite_support_code():
    """
    Regression test for a compile-time failure (SupportCodeError).

    Compiling the gradient graph built below used to raise an error at
    compile time.
    Commit 3d1690fa346103594356ecaeceeb2c6757b45d2b fixed that.
    """
    # Shared float32 variables, all zero-initialised.
    X = tcn.shared_constructor(
        value=numpy.zeros((100, 10), dtype="float32"), name='X')
    W = tcn.shared_constructor(
        value=numpy.zeros((10, 1), dtype="float32"), name='W')
    U = T.dot(X, W)
    Y = tcn.shared_constructor(
        value=numpy.zeros((100, 1), dtype="float32"), name='Y')

    residual = Y - U
    P = T.exp(-(residual ** 2))
    epsilon = numpy.asarray(0.001, dtype="float32")
    # This is the expression that used to trigger the SupportCodeError.
    NLL = -T.mean(T.log(P + epsilon))
    G = T.grad(NLL, wrt=[W])

    # Switch the identify_1pexp_bug warning flag off only while compiling,
    # and always put the previous value back afterwards.
    saved_flag = theano.config.warn.identify_1pexp_bug
    theano.config.warn.identify_1pexp_bug = False
    try:
        f_grad = theano.function(inputs=[], outputs=G, mode=mode_with_gpu)
    finally:
        theano.config.warn.identify_1pexp_bug = saved_flag

    # The compiled function must also run without error.
    f_grad()

    # Exactly one T.Elemwise node and one tcn.GpuElemwise node are expected
    # in the compiled graph.
    topo = f_grad.maker.env.toposort()
    n_elemwise = sum(isinstance(node.op, T.Elemwise) for node in topo)
    n_gpu_elemwise = sum(isinstance(node.op, tcn.GpuElemwise)
                         for node in topo)
    assert n_elemwise == 1
    assert n_gpu_elemwise == 1
def speed_elemwise_collapse(): def speed_elemwise_collapse():
......
...@@ -2422,10 +2422,11 @@ class Composite(ScalarOp): ...@@ -2422,10 +2422,11 @@ class Composite(ScalarOp):
rval = [] rval = []
for subnode, subnodename in zip(self.env.toposort(), self.nodenames): for subnode, subnodename in zip(self.env.toposort(), self.nodenames):
try: try:
rval.append( subnode_support_code = subnode.op.c_support_code_apply(
subnode.op.c_support_code_apply(
subnode, subnode,
subnodename % dict(nodename=name))) subnodename % dict(nodename=name))
if subnode_support_code:
rval.append(subnode_support_code)
except gof.utils.MethodNotDefined: except gof.utils.MethodNotDefined:
pass pass
# there should be no need to remove duplicate code blocks because # there should be no need to remove duplicate code blocks because
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论