提交 03281b5f authored 作者: slefrancois's avatar slefrancois

Added tutorial for function.copy in example.txt.

Changed function.copy to ignore unused inputs after deleting updates. Added test to make sure copy ignores unused inputs.
上级 b96e3954
...@@ -274,9 +274,6 @@ In practice, a good way of thinking about the ``givens`` is as a mechanism ...@@ -274,9 +274,6 @@ In practice, a good way of thinking about the ``givens`` is as a mechanism
that allows you to replace any part of your formula with a different that allows you to replace any part of your formula with a different
expression that evaluates to a tensor of same shape and dtype. expression that evaluates to a tensor of same shape and dtype.
.. _using_random_numbers:
.. note:: .. note::
Theano shared variable broadcast pattern default to False for each Theano shared variable broadcast pattern default to False for each
...@@ -285,6 +282,54 @@ expression that evaluates to a tensor of same shape and dtype. ...@@ -285,6 +282,54 @@ expression that evaluates to a tensor of same shape and dtype.
different pattern, just pass it as a parameter different pattern, just pass it as a parameter
``theano.shared(..., broadcastable=(True, False))`` ``theano.shared(..., broadcastable=(True, False))``
Copying functions
=================
Theano functions can be copied, which can be useful for creating similar
functions but with different shared variables or updates. This is done using
the :func:`copy()<theano.compile.function_module.Function.copy>` method of ``function`` objects. The optimized graph of the original function is copied,
so compilation only needs to be performed once.
Let's start from the accumulator defined above:
>>> import theano
>>> import theano.tensor as T
>>> state = theano.shared(0)
>>> inc = T.iscalar('inc')
>>> accumulator = theano.function([inc], state, updates=[(state, state+inc)])
We can use it to increment the state as usual:
>>> accumulator(10)
>>> print(state.get_value())
10
We can use ``copy()`` to create a similar accumulator but with its own internal state
using the ``swap`` parameter, which is a dictionary of shared variables to exchange:
>>> new_state = theano.shared(0)
>>> new_accumulator = accumulator.copy(swap={state:new_state})
>>> new_accumulator(100)
>>> print(new_state.get_value())
100
The state of the first function is left untouched:
>>> print(state.get_value())
10
We now create a copy with updates removed using the ``delete_updates``
parameter, which is set to ``False`` by default:
>>> null_accumulator = accumulator.copy(delete_updates=True)
As expected, the shared state is no longer updated:
>>> null_accumulator(9000)
>>> print(state.get_value())
10
.. _using_random_numbers:
Using Random Numbers Using Random Numbers
==================== ====================
......
...@@ -714,7 +714,13 @@ class Function(object): ...@@ -714,7 +714,13 @@ class Function(object):
f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy, f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
mode=maker.mode, profile=profile, mode=maker.mode, profile=profile,
on_unused_input=maker.on_unused_input, # When removing updates containing variables
# not used in the output function, copy
# generates an unused implicit input.
# We ignore the resulting errors,
# but could change it to 'warn' if this might
# cause problems.
on_unused_input='ignore',
function_builder=maker.function_builder, function_builder=maker.function_builder,
# As this is an optimized graph, it # As this is an optimized graph, it
# can contain inplace. DebugMode check # can contain inplace. DebugMode check
......
...@@ -383,6 +383,13 @@ class T_function(unittest.TestCase): ...@@ -383,6 +383,13 @@ class T_function(unittest.TestCase):
assert cpy(1)[0] == 4 assert cpy(1)[0] == 4
assert cpy(1)[0] == 4 assert cpy(1)[0] == 4
# Test if unused implicit inputs from delete_updates are ignored
# as intended.
for mode in ["FAST_RUN", "FAST_COMPILE"]:
ori = theano.function([x], x, mode=mode, updates={z: z * 2})
cpy = ori.copy(delete_updates=True)
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
x, s = T.scalars('xs') x, s = T.scalars('xs')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论