提交 bb1a9a1c authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: detect a class of situations we can't handle and crash rather…

map_variables: detect a class of situations we can't handle and crash rather than compute the wrong result
上级 b8349f77
...@@ -384,8 +384,6 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs): ...@@ -384,8 +384,6 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs):
# variables, which we cannot directly use inside the inner graph. # variables, which we cannot directly use inside the inner graph.
# we need to create inner inputs to access them through. # we need to create inner inputs to access them through.
# TODO: handle potential updates of newly introduced shared variables.
outer_to_inner = dict(zip(outer_inputs, inner_inputs)) outer_to_inner = dict(zip(outer_inputs, inner_inputs))
extra_inner_inputs = [] extra_inner_inputs = []
extra_outer_inputs = [] extra_outer_inputs = []
...@@ -424,6 +422,16 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs): ...@@ -424,6 +422,16 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs):
replacements.append((input_, new_input)) replacements.append((input_, new_input))
for outer_input in foreign_inputs: for outer_input in foreign_inputs:
if getattr(outer_input, "update", False):
# when theano.scan() constructs a scan node, it detects
# shared variables with updates and returns these updates
# to the user. we need to do the same thing for every new
# use of such a variable that is introduced. it's hard to
# do that at this point.
raise NotImplementedError(
"Replacement introduces shared variable %s "
"which has an update associated with it. This "
"is not currently supported." % outer_input)
# if this foreign input is not already available # if this foreign input is not already available
# as an inner input, connect it through a new # as an inner input, connect it through a new
# inner input # inner input
......
import sys
import itertools import itertools
import numpy import numpy
import theano import theano
...@@ -86,6 +87,64 @@ class TestMapVariables(object): ...@@ -86,6 +87,64 @@ class TestMapVariables(object):
rval = f(x=numpy.array([1, 2, 3], dtype=numpy.float32), outer=0.5) rval = f(x=numpy.array([1, 2, 3], dtype=numpy.float32), outer=0.5)
assert numpy.array_equal(rval, [[1, 3, 6], [-1, -3, -6]]) assert numpy.array_equal(rval, [[1, 3, 6], [-1, -3, -6]])
def test_scan_with_shared_update(self):
    """Check that map_variables refuses to introduce a shared variable
    with an update into a scan's inner graph.

    The replacement attached to ``r`` references ``counter``, a shared
    variable that carries an ``update``.  map_variables is expected to
    raise NotImplementedError rather than silently compute the wrong
    result; the test fails if no such exception is raised.
    """
    x = tensor.vector('x')

    # counts how many times its value is used
    counter = theano.shared(0, name="shared")
    counter.update = counter + 1

    def step(x, a):
        r = a + x
        # introducing a shared variable with an update into the
        # inner graph is unsupported and the code must crash rather
        # than silently produce the wrong result.
        r.tag.replacement = counter * (a - x)
        return r

    s, _ = theano.scan(step, sequences=x,
                       outputs_info=[numpy.array(0.)])

    try:
        map_variables(self.replacer, [s])
    except NotImplementedError:
        # sys.exc_info() is used instead of binding the exception in the
        # except clause so the same code parses on Python 2 and Python 3.
        e = sys.exc_info()[1]
        assert "introduces shared variable" in str(e)
    else:
        # returning a value here would be reported as a pass by the test
        # runner; fail explicitly instead.
        raise AssertionError(
            "map_variables did not raise NotImplementedError for a "
            "replacement that introduces a shared variable with an update")
def test_scan_with_shared_update3(self):
    """Check that map_variables refuses a replacement that adds a new use
    of a shared variable with an update that the scan already uses.

    Here ``counter`` is already part of the step function's result
    (``r + counter``), and the replacement attached to ``r`` references
    it again.  map_variables is expected to raise NotImplementedError;
    the test fails if no such exception is raised.
    """
    x = tensor.vector('x')

    # counts how many times its value is used
    counter = theano.shared(0, name="shared")
    counter.update = counter + 1

    def step(x, a):
        r = a + x
        # introducing a shared variable with an update into the
        # inner graph is unsupported and the code must crash rather
        # than silently produce the wrong result.
        r.tag.replacement = counter * (a - x)
        # the shared variable was already present, but the
        # replacement changes the number of times it is used,
        # which would have to change the updates, which is
        # unsupported.
        return r + counter

    s, _ = theano.scan(step, sequences=x,
                       outputs_info=[numpy.array(0.)])

    try:
        map_variables(self.replacer, [s])
    except NotImplementedError:
        # sys.exc_info() is used instead of binding the exception in the
        # except clause so the same code parses on Python 2 and Python 3.
        e = sys.exc_info()[1]
        assert "introduces shared variable" in str(e)
    else:
        # returning a value here would be reported as a pass by the test
        # runner; fail explicitly instead.
        raise AssertionError(
            "map_variables did not raise NotImplementedError for a "
            "replacement that reuses a shared variable with an update")
def test_leaf_inside_scan(self): def test_leaf_inside_scan(self):
x = tensor.vector('x') x = tensor.vector('x')
y = tensor.scalar('y') y = tensor.scalar('y')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论