提交 3f3bf149 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3120 from lamblin/fix_clinker

Fix CLinker
...@@ -282,10 +282,16 @@ def get_nothing(r, name, sub): ...@@ -282,10 +282,16 @@ def get_nothing(r, name, sub):
def get_c_declare(r, name, sub): def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name""" """Wrapper around c_declare that declares py_name"""
# The declaration will be used by the Apply node that
if any([c != "output" and getattr(c.op, 'check_input', # is computing it (`r.owner`), and by each of the clients.
config.check_input) for (c, _) in r.clients]) or ( # If some of these have `check_input=True` in their `.op`,
r.owner and getattr(r.owner.op, 'check_input', True)): # it means they need `r`'s dtype to be declared, so
# we have to pass `check_input=True` to `c_declare`.
if ((any([getattr(c.op, 'check_input', config.check_input)
for (c, _) in r.clients
if not isinstance(c, string_types)]) or
(r.owner and
getattr(r.owner.op, 'check_input', config.check_input)))):
c_declare = r.type.c_declare(name, sub, True) c_declare = r.type.c_declare(name, sub, True)
else: else:
c_declare = r.type.c_declare(name, sub, False) c_declare = r.type.c_declare(name, sub, False)
...@@ -306,13 +312,21 @@ def get_c_init(r, name, sub): ...@@ -306,13 +312,21 @@ def get_c_init(r, name, sub):
def get_c_extract(r, name, sub): def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage.""" """Wrapper around c_extract that initializes py_name from storage."""
if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in # `c_extract` is called when getting the value of an apply node's
r.clients]): # input from the compute map, before being used by its clients.
# If one of the clients has `check_input=True`, we need to perform
# checks on the variable.
# However that code is not used by C code of the apply node creating
# this variable, so there is no need to check `r.owner.op.check_input`.
if any([getattr(c.op, 'check_input', config.check_input)
for (c, _) in r.clients
if not isinstance(c, string_types)]):
# check_broadcast is just an hack to easily remove just the # check_broadcast is just an hack to easily remove just the
# broadcast check on the old GPU back-end. This check isn't # broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU. # done in the new GPU back-end or on the CPU.
if any([getattr(c.op, 'check_broadcast', True) for (c, _) in if any([getattr(c.op, 'check_broadcast', True)
r.clients]): for (c, _) in r.clients
if not isinstance(c, string_types)]):
c_extract = r.type.c_extract(name, sub, True) c_extract = r.type.c_extract(name, sub, True)
else: else:
try: try:
...@@ -333,10 +347,18 @@ def get_c_extract(r, name, sub): ...@@ -333,10 +347,18 @@ def get_c_extract(r, name, sub):
def get_c_extract_out(r, name, sub): def get_c_extract_out(r, name, sub):
"""Wrapper around c_extract_out that initializes py_name from storage.""" """Wrapper around c_extract_out that initializes py_name from storage."""
# `c_extract_out` is used to extract an output variable from
# the compute map, to be used as pre-allocated memory for `r`
# before its value gets computed.
# If the node producing `r` has `check_input=True`, it may
# also perform type checks on the initial value of the output,
# so we need to pass `check_input=True` to `c_extract_out`.
# However, that code is not used by potential clients of `r`,
# so we do not need to check them.
check_input = getattr(r.owner.op, 'check_input', config.check_input)
# check_broadcast is just an hack to easily remove just the # check_broadcast is just an hack to easily remove just the
# broadcast check on the old GPU back-end. This check isn't # broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU. # done in the new GPU back-end or on the CPU.
check_input = getattr(r.owner.op, 'check_input', config.check_input)
if getattr(r.owner.op, 'check_broadcast', True): if getattr(r.owner.op, 'check_broadcast', True):
c_extract = r.type.c_extract_out(name, sub, check_input) c_extract = r.type.c_extract_out(name, sub, check_input)
else: else:
...@@ -554,8 +576,10 @@ class CLinker(link.Linker): ...@@ -554,8 +576,10 @@ class CLinker(link.Linker):
# what to do at the beginning of each run, # what to do at the beginning of each run,
# what to do at the end of each run]] # what to do at the end of each run]]
if variable in self.inputs: if variable in self.inputs:
# we need to extract the new inputs at each run # We need to extract the new inputs at each run
# they do not need to be relayed to Python, so we don't sync # they do not need to be relayed to Python, so we don't sync.
# If the variable is both an input and an output, there is
# no need to synchronize either, it is already up-to-date.
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]] [get_c_declare, get_c_extract, get_c_cleanup]]
elif variable in self.orphans: elif variable in self.orphans:
...@@ -967,6 +991,10 @@ class CLinker(link.Linker): ...@@ -967,6 +991,10 @@ class CLinker(link.Linker):
if output_storage is None: if output_storage is None:
map = {} map = {}
output_storage = [] output_storage = []
# Initialize the map with the inputs, as some outputs may
# be inputs as well.
for i, variable in enumerate(self.inputs):
map[variable] = input_storage[i]
for variable in self.outputs: for variable in self.outputs:
if variable not in map: if variable not in map:
map[variable] = [None] map[variable] = [None]
......
...@@ -5,6 +5,8 @@ import unittest ...@@ -5,6 +5,8 @@ import unittest
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
import numpy
import theano import theano
from theano.gof.link import PerformLinker from theano.gof.link import PerformLinker
from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
...@@ -362,3 +364,64 @@ def test_c_fail_error(): ...@@ -362,3 +364,64 @@ def test_c_fail_error():
print('Yay, TEST PASSED') print('Yay, TEST PASSED')
return # test passed return # test passed
assert 0 # test failed assert 0 # test failed
def test_shared_input_output():
    # Regression test for a bug reported on the mailing list by
    # Alberto Orlandi:
    # https://groups.google.com/d/topic/theano-users/6dLaEqc2R6g/discussion
    # The shared variable is returned and updated by the same function,
    # so it is both an input and an output of that function.
    inc = theano.tensor.iscalar('inc')
    c_mode = theano.Mode(linker=theano.gof.CLinker())

    # ---- Scalar shared state -------------------------------------------
    state = theano.shared(0)
    state.name = 'state'

    # `f` is compiled with the CLinker mode, `g` with the default mode;
    # both read and update the same shared variable.
    f = theano.function([inc], state, updates=[(state, state + inc)],
                        mode=c_mode)
    g = theano.function([inc], state, updates=[(state, state + inc)])

    # Adding 0 leaves the state untouched: both report the initial value.
    f_val = f(0)
    g_val = g(0)
    assert f_val == g_val == 0, (f_val, g_val)

    # Incrementing through `f` returns the value held before the update.
    f_prev = f(2)
    assert f_prev == f_val, (f_prev, f_val)
    f_val = f(0)
    g_val = g(0)
    assert f_val == g_val == 2, (f_val, g_val)

    # Incrementing through `g` returns the value held before the update.
    g_prev = g(3)
    assert g_prev == g_val, (g_prev, g_val)
    f_val = f(0)
    g_val = g(0)
    assert f_val == g_val == 5, (f_val, g_val)

    # ---- Vector shared state -------------------------------------------
    vstate = theano.shared(numpy.zeros(3, dtype='int32'))
    vstate.name = 'vstate'

    fv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)],
                         mode=c_mode)
    gv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)])

    # Adding 0 leaves the state untouched: both report the initial value.
    fv_val = fv(0)
    gv_val = gv(0)
    assert numpy.all(fv_val == 0), fv_val
    assert numpy.all(gv_val == 0), gv_val

    # Incrementing through `fv` returns the value held before the update.
    fv_prev = fv(2)
    assert numpy.all(fv_prev == fv_val), (fv_prev, fv_val)
    fv_val = fv(0)
    gv_val = gv(0)
    assert numpy.all(fv_val == 2), fv_val
    assert numpy.all(gv_val == 2), gv_val

    # Incrementing through `gv` returns the value held before the update.
    gv_prev = gv(3)
    assert numpy.all(gv_prev == gv_val), (gv_prev, gv_val)
    fv_val = fv(0)
    gv_val = gv(0)
    assert numpy.all(fv_val == 5), fv_val
    assert numpy.all(gv_val == 5), gv_val
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论