Commit 77236f47 authored by Frédéric Bastien

Merge pull request #4419 from slefrancois/function_copy

Function copy
...@@ -274,9 +274,6 @@ In practice, a good way of thinking about the ``givens`` is as a mechanism
that allows you to replace any part of your formula with a different
expression that evaluates to a tensor of same shape and dtype.
.. _using_random_numbers:
.. note::
Theano shared variable broadcast pattern default to False for each
...@@ -285,6 +282,57 @@ expression that evaluates to a tensor of same shape and dtype.
different pattern, just pass it as a parameter
``theano.shared(..., broadcastable=(True, False))``
Copying functions
=================
Theano functions can be copied, which can be useful for creating similar
functions but with different shared variables or updates. This is done using
the :func:`copy()<theano.compile.function_module.Function.copy>` method of ``function`` objects. The optimized graph of the original function is copied,
so compilation only needs to be performed once.
Let's start from the accumulator defined above:
>>> import theano
>>> import theano.tensor as T
>>> state = theano.shared(0)
>>> inc = T.iscalar('inc')
>>> accumulator = theano.function([inc], state, updates=[(state, state+inc)])
We can use it to increment the state as usual:
>>> accumulator(10)
array(0)
>>> print(state.get_value())
10
We can use ``copy()`` to create a similar accumulator but with its own internal state
using the ``swap`` parameter, which is a dictionary of shared variables to exchange:
>>> new_state = theano.shared(0)
>>> new_accumulator = accumulator.copy(swap={state:new_state})
>>> new_accumulator(100)
[array(0)]
>>> print(new_state.get_value())
100
The state of the first function is left untouched:
>>> print(state.get_value())
10
We now create a copy with updates removed using the ``delete_updates``
parameter, which is set to ``False`` by default:
>>> null_accumulator = accumulator.copy(delete_updates=True)
As expected, the shared state is no longer updated:
>>> null_accumulator(9000)
[array(10)]
>>> print(state.get_value())
10
.. _using_random_numbers:
Using Random Numbers
====================
......
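To make the semantics documented in the ``Copying functions`` section concrete, here is a minimal plain-Python sketch. It is a toy model and not Theano's implementation: ``SharedBox`` and ``ToyAccumulator`` are hypothetical names invented for illustration, the toy returns a bare value where the copied Theano function in the doctest returns a list, and the reuse of the optimized graph is not modelled at all.

```python
class SharedBox:
    """Hypothetical stand-in for a shared variable: holds one mutable value."""

    def __init__(self, value):
        self.value = value


class ToyAccumulator:
    """Hypothetical stand-in for a compiled function with one update."""

    def __init__(self, state, apply_updates=True):
        self.state = state
        self.apply_updates = apply_updates

    def __call__(self, inc):
        old = self.state.value            # the output is read before the update
        if self.apply_updates:
            self.state.value = old + inc  # the update: state <- state + inc
        return old

    def copy(self, swap=None, delete_updates=False):
        # swap maps an original shared container to its replacement;
        # anything not listed stays shared with the original function.
        swap = swap or {}
        new_state = swap.get(self.state, self.state)
        keep_updates = self.apply_updates and not delete_updates
        return ToyAccumulator(new_state, keep_updates)


state = SharedBox(0)
accumulator = ToyAccumulator(state)
accumulator(10)                           # state.value is now 10

new_state = SharedBox(0)
new_accumulator = accumulator.copy(swap={state: new_state})
new_accumulator(100)                      # only new_state changes

null_accumulator = accumulator.copy(delete_updates=True)
null_result = null_accumulator(9000)      # reads state, never writes it
```

In this model, ``swap`` decides which container a copy reads and writes, while ``delete_updates`` decides whether it writes at all. A copy made without ``swap`` keeps pointing at the original container, which is why ``null_accumulator`` still sees the value left behind by ``accumulator``.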
...@@ -714,7 +714,13 @@ class Function(object):
f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
mode=maker.mode, profile=profile,
on_unused_input=maker.on_unused_input, # When removing updates containing variables
# not used in the output function, copy
# generates an unused implicit input.
# We ignore the resulting errors,
# but could change it to 'warn' if this might
# cause problems.
on_unused_input='ignore',
function_builder=maker.function_builder,
# As this is an optimized graph, it
# can contain inplace. DebugMode check
......
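The reasoning in the comment added to ``Function.copy`` above (deleting updates can leave inputs that nothing refers to any more) can be illustrated with a small plain-Python sketch. This is a simplified model under stated assumptions, not Theano's code: ``unused_inputs`` and ``check_inputs`` are hypothetical helpers, expressions are reduced to sets of variable names, and ``ValueError`` stands in for whatever error the library actually raises. Only the three policy names ``'raise'``, ``'warn'`` and ``'ignore'`` are taken from ``on_unused_input``.

```python
import warnings


def unused_inputs(inputs, outputs, updates):
    """Return the inputs that no output or update expression depends on.

    Each expression is modelled as the set of variable names it uses.
    """
    used = set()
    for deps in outputs:
        used |= deps
    for deps in updates.values():
        used |= deps
    return [name for name in inputs if name not in used]


def check_inputs(inputs, outputs, updates, on_unused_input="raise"):
    """Apply a 'raise' / 'warn' / 'ignore' policy to unused inputs."""
    unused = unused_inputs(inputs, outputs, updates)
    if unused:
        message = "unused inputs: " + ", ".join(unused)
        if on_unused_input == "raise":
            raise ValueError(message)  # stand-in for the real error type
        if on_unused_input == "warn":
            warnings.warn(message)
    return unused


# Mirrors theano.function([x, w], x, updates={z: z + w}) from the new test:
# x and w are explicit inputs, the shared variable z is an implicit one.
inputs = ["x", "w", "z"]
outputs = [{"x"}]
updates = {"z": {"z", "w"}}

# With the update in place, every input is used.
before = check_inputs(inputs, outputs, updates)

# After dropping the updates, w and z are orphaned; 'ignore' lets this pass.
after = check_inputs(inputs, outputs, {}, on_unused_input="ignore")
```

In this model, keeping the original function's policy would make the copy fail whenever that policy is ``'raise'``, which is the situation the hard-coded ``'ignore'`` avoids.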
...@@ -366,6 +366,7 @@ class T_function(unittest.TestCase):
assert in1.value is in2.value
def test_copy_delete_updates(self):
w = T.iscalar('w')
x = T.fscalar('x')
# SharedVariable for tests, one of them has update
y = theano.shared(value=1, name='y')
...@@ -383,6 +384,15 @@ class T_function(unittest.TestCase):
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
# Test if unused implicit and explicit inputs from delete_updates
# are ignored as intended.
for mode in ["FAST_RUN", "FAST_COMPILE"]:
ori = theano.function([x], x, mode=mode, updates={z: z * 2})
cpy = ori.copy(delete_updates=True)
ori = theano.function([x, w], x, mode=mode, updates={z: z + w})
cpy = ori.copy(delete_updates=True)
def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x, s = T.scalars('xs')
......