提交 3f3bf149 作者: Frédéric Bastien

Merge pull request #3120 from lamblin/fix_clinker

Fix CLinker
......@@ -282,10 +282,16 @@ def get_nothing(r, name, sub):
def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name"""
if any([c != "output" and getattr(c.op, 'check_input',
config.check_input) for (c, _) in r.clients]) or (
r.owner and getattr(r.owner.op, 'check_input', True)):
# The declaration will be used by the Apply node that
# is computing it (`r.owner`), and by each of the clients.
# If some of these have `check_input=True` in their `.op`,
# it means they need `r`'s dtype to be declared, so
# we have to pass `check_input=True` to `c_declare`.
if ((any([getattr(c.op, 'check_input', config.check_input)
for (c, _) in r.clients
if not isinstance(c, string_types)]) or
(r.owner and
getattr(r.owner.op, 'check_input', config.check_input)))):
c_declare = r.type.c_declare(name, sub, True)
else:
c_declare = r.type.c_declare(name, sub, False)
......@@ -306,13 +312,21 @@ def get_c_init(r, name, sub):
def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage."""
if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in
r.clients]):
# `c_extract` is called when getting the value of an apply node's
# input from the compute map, before being used by its clients.
# If one of the clients has `check_input=True`, we need to perform
# checks on the variable.
# However that code is not used by C code of the apply node creating
# this variable, so there is no need to check `r.owner.op.check_input`.
if any([getattr(c.op, 'check_input', config.check_input)
for (c, _) in r.clients
if not isinstance(c, string_types)]):
# check_broadcast is just a hack to easily remove just the
# broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU.
if any([getattr(c.op, 'check_broadcast', True) for (c, _) in
r.clients]):
if any([getattr(c.op, 'check_broadcast', True)
for (c, _) in r.clients
if not isinstance(c, string_types)]):
c_extract = r.type.c_extract(name, sub, True)
else:
try:
......@@ -333,10 +347,18 @@ def get_c_extract(r, name, sub):
def get_c_extract_out(r, name, sub):
"""Wrapper around c_extract_out that initializes py_name from storage."""
# `c_extract_out` is used to extract an output variable from
# the compute map, to be used as pre-allocated memory for `r`
# before its value gets computed.
# If the node producing `r` has `check_input=True`, it may
# also perform type checks on the initial value of the output,
# so we need to pass `check_input=True` to `c_extract_out`.
# However, that code is not used by potential clients of `r`,
# so we do not need to check them.
check_input = getattr(r.owner.op, 'check_input', config.check_input)
# check_broadcast is just a hack to easily remove just the
# broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU.
check_input = getattr(r.owner.op, 'check_input', config.check_input)
if getattr(r.owner.op, 'check_broadcast', True):
c_extract = r.type.c_extract_out(name, sub, check_input)
else:
......@@ -554,8 +576,10 @@ class CLinker(link.Linker):
# what to do at the beginning of each run,
# what to do at the end of each run]]
if variable in self.inputs:
# we need to extract the new inputs at each run
# they do not need to be relayed to Python, so we don't sync
# We need to extract the new inputs at each run
# they do not need to be relayed to Python, so we don't sync.
# If the variable is both an input and an output, there is
# no need to synchronize either, it is already up-to-date.
policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]]
elif variable in self.orphans:
......@@ -967,6 +991,10 @@ class CLinker(link.Linker):
if output_storage is None:
map = {}
output_storage = []
# Initialize the map with the inputs, as some outputs may
# be inputs as well.
for i, variable in enumerate(self.inputs):
map[variable] = input_storage[i]
for variable in self.outputs:
if variable not in map:
map[variable] = [None]
......
......@@ -5,6 +5,8 @@ import unittest
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
import numpy
import theano
from theano.gof.link import PerformLinker
from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
......@@ -362,3 +364,64 @@ def test_c_fail_error():
print('Yay, TEST PASSED')
return # test passed
assert 0 # test failed
def test_shared_input_output():
    """Check functions whose shared variable is both an input and an output.

    Regression test for a bug reported on the mailing list by
    Alberto Orlandi:
    https://groups.google.com/d/topic/theano-users/6dLaEqc2R6g/discussion

    For a scalar and then a vector shared variable, two functions are
    compiled over the same shared state: one with a ``CLinker`` mode and
    one without an explicit mode.  Each returns the shared variable and
    updates it by ``inc``.  The test checks that a call returns the value
    held before the update, and that both functions observe increments
    made through either of them.
    """
    inc = theano.tensor.iscalar('inc')

    # --- Scalar shared state -------------------------------------------
    state = theano.shared(0)
    state.name = 'state'
    c_linker = theano.gof.CLinker()
    c_mode = theano.Mode(linker=c_linker)
    c_fn = theano.function(
        [inc], state, updates=[(state, state + inc)], mode=c_mode)
    ref_fn = theano.function(
        [inc], state, updates=[(state, state + inc)])

    # Both functions start from the initial value.
    c_val = c_fn(0)
    ref_val = ref_fn(0)
    assert c_val == ref_val == 0, (c_val, ref_val)

    # Incrementing through the CLinker function returns the old value...
    c_prev = c_fn(2)
    assert c_prev == c_val, (c_prev, c_val)

    # ...and the new value is visible from both functions afterwards.
    c_val = c_fn(0)
    ref_val = ref_fn(0)
    assert c_val == ref_val == 2, (c_val, ref_val)

    # Incrementing through the other function also returns the old value.
    ref_prev = ref_fn(3)
    assert ref_prev == ref_val, (ref_prev, ref_val)

    c_val = c_fn(0)
    ref_val = ref_fn(0)
    assert c_val == ref_val == 5, (c_val, ref_val)

    # --- Vector shared state -------------------------------------------
    vstate = theano.shared(numpy.zeros(3, dtype='int32'))
    vstate.name = 'vstate'
    c_vfn = theano.function(
        [inc], vstate, updates=[(vstate, vstate + inc)], mode=c_mode)
    ref_vfn = theano.function(
        [inc], vstate, updates=[(vstate, vstate + inc)])

    # Both functions start from the initial value.
    c_vec = c_vfn(0)
    ref_vec = ref_vfn(0)
    assert numpy.all(c_vec == 0), c_vec
    assert numpy.all(ref_vec == 0), ref_vec

    # Incrementing through the CLinker function returns the old value.
    c_vprev = c_vfn(2)
    assert numpy.all(c_vprev == c_vec), (c_vprev, c_vec)

    c_vec = c_vfn(0)
    ref_vec = ref_vfn(0)
    assert numpy.all(c_vec == 2), c_vec
    assert numpy.all(ref_vec == 2), ref_vec

    # Incrementing through the other function returns the old value.
    ref_vprev = ref_vfn(3)
    assert numpy.all(ref_vprev == ref_vec), (ref_vprev, ref_vec)

    c_vec = c_vfn(0)
    ref_vec = ref_vfn(0)
    assert numpy.all(c_vec == 5), c_vec
    assert numpy.all(ref_vec == 5), ref_vec
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论