Commit 6d39faaf authored by Brandon T. Willard; committed by Brandon T. Willard

Remove the scan_ prefix from modules and relocate them to theano.scan

The one drawback to this change is that `theano.scan`, i.e. the `scan` function exposed by `theano/__init__.py`, shadows the `theano.scan` sub-package. In other words, modules in the `theano.scan` sub-package cannot be accessed via the top-level `theano` module object (e.g. `import theano; theano.scan.op.Scan` won't work). This is a minor inconvenience, and, since internal library code is generally expected to import objects directly from the modules in which they're defined, the appearance of this problem will serve as a welcome warning.
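The shadowing described above can be illustrated without Theano itself. The sketch below simulates the situation with `types.ModuleType` objects (the package and module objects here are stand-ins, not the real `theano` package): once `__init__.py` rebinds the name `scan` to a function, the sub-package of the same name is no longer reachable as an attribute of the top-level module.

```python
import types

# Stand-ins for the real package and sub-package objects.
pkg = types.ModuleType("theano")
scan_subpkg = types.ModuleType("theano.scan")
scan_subpkg.op = types.ModuleType("theano.scan.op")

# `import theano.scan.op` would set this attribute on the package...
pkg.scan = scan_subpkg

# ...but a later `from theano.scan.basic import scan`-style binding in
# `__init__.py` shadows the sub-package with the function of the same name:
def scan(fn, sequences=None, **kwargs):
    """Stand-in for the real `scan` function."""

pkg.scan = scan

# The sub-package is now unreachable through the top-level module object,
# while the sub-package object itself (reached by a direct import) is fine.
assert callable(pkg.scan)
assert not hasattr(pkg.scan, "op")
assert hasattr(scan_subpkg, "op")
```

This is why `import theano; theano.scan.op.Scan` fails while `from theano.scan.op import Scan` still works.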
Parent 77b87e95
@@ -46,24 +46,24 @@ Relevant code files
===================
The implementation of Scan is spread over several files in
-``theano/scan_module``. The different files, and sections of the code they
+``theano/scan``. The different files, and sections of the code they
deal with, are:
-* ``scan.py`` implements the ``scan`` function. The ``scan`` function
+* ``basic.py`` implements the ``scan`` function. The ``scan`` function
  arranges the arguments of scan correctly, constructs the scan op and
  afterwards calls the constructed scan op on the arguments. This function
  takes care of figuring out missing inputs and shared variables.
-* ``scan_op.py`` implements the ``Scan`` op class. The ``Scan`` respects
+* ``op.py`` implements the ``Scan`` op class. The ``Scan`` respects
  the ``Op`` interface, and contains most of the logic of the scan operator.
-* ``scan_utils.py`` contains several helpful functions used throughout the
+* ``utils.py`` contains several helpful functions used throughout the
  other files that are specific to the scan operator.
-* ``scan_views.py`` contains different views of the scan op that have
+* ``views.py`` contains different views of the scan op that have
  simpler and easier signatures to be used in specific cases.
-* ``scan_opt.py`` contains the list of all Theano graph optimizations for the
+* ``opt.py`` contains the list of all Theano graph optimizations for the
  scan operator.
@@ -269,14 +269,14 @@ Because of the complexity involved in dealing with Scan, a large number of
helper classes and functions have been developed over time to implement
operations commonly needed when dealing with the scan op. The scan op
itself defines a large number of them, and others can be found in the file
-``scan_utils.py``. This section aims to point out the most useful ones, sorted
+``utils.py``. This section aims to point out the most useful ones, sorted
by usage.

Accessing/manipulating Scan's inputs and outputs by type
--------------------------------------------------------
-Declared in ``scan_utils.py``, the class ``scan_args`` handles the
+Declared in ``utils.py``, the class ``scan_args`` handles the
parsing of the inputs and outputs (both inner and outer) to a format
that is easier to analyse and manipulate. Without this class,
analysing Scan's inputs and outputs often required convoluted logic
......
@@ -6,7 +6,7 @@
OpFromGraph
===========
-This page describes :class:`theano.OpFromGraph
+This page describes :class:`theano.compile.builders.OpFromGraph
<theano.compile.builders.OpFromGraph>`, an Op that allows one to
encapsulate a Theano graph in an op.
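The idea behind `OpFromGraph` can be sketched in plain Python (this is a hypothetical eager-mode analogue, not Theano's API; `op_from_graph` is an invented helper): capture a computation over a list of named inputs once, then reuse it as a single callable "op". The real `OpFromGraph` does this symbolically, producing an Op that can appear in larger graphs.

```python
# Hypothetical analogue: wrap an expression over named inputs as one callable.
def op_from_graph(inputs, expr):
    def op(*args):
        env = dict(zip(inputs, args))  # bind positional args to input names
        return expr(env)
    return op

# Encapsulate e = x + y * z once, then apply it to concrete arguments.
op = op_from_graph(["x", "y", "z"],
                   lambda env: env["x"] + env["y"] * env["z"])
print(op(1, 2, 3))  # 7
```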
......
@@ -391,7 +391,7 @@
"source": [
    "x, y, z = tt.scalars('xyz')\n",
    "e = tt.nnet.sigmoid((x + y + z)**2)\n",
-    "op = th.OpFromGraph([x, y, z], [e])\n",
+    "op = th.compile.builders.OpFromGraph([x, y, z], [e])\n",
    "\n",
    "e2 = op(x, y, z) + op(z, y, x)\n",
    "f = th.function([x, y, z], e2)"
@@ -436,9 +436,9 @@
"source": [
    "x, y, z = tt.scalars('xyz')\n",
    "e = x * y\n",
-    "op = th.OpFromGraph([x, y], [e])\n",
+    "op = th.compile.builders.OpFromGraph([x, y], [e])\n",
    "e2 = op(x, y) + z\n",
-    "op2 = th.OpFromGraph([x, y, z], [e2])\n",
+    "op2 = th.compile.builders.OpFromGraph([x, y, z], [e2])\n",
    "e3 = op2(x, y, z) + z\n",
    "f = th.function([x, y, z], [e3])"
]
......
@@ -220,7 +220,7 @@ defines a nested graph, which will be visualized accordingly by
x, y, z = tt.scalars('xyz')
e = tt.nnet.sigmoid((x + y + z)**2)
-op = th.OpFromGraph([x, y, z], [e])
+op = th.compile.builders.OpFromGraph([x, y, z], [e])
e2 = op(x, y, z) + op(z, y, x)
f = th.function([x, y, z], e2)
@@ -249,9 +249,9 @@ the following example.
x, y, z = tt.scalars('xyz')
e = x * y
-op = th.OpFromGraph([x, y], [e])
+op = th.compile.builders.OpFromGraph([x, y], [e])
e2 = op(x, y) + z
-op2 = th.OpFromGraph([x, y, z], [e2])
+op2 = th.compile.builders.OpFromGraph([x, y, z], [e2])
e3 = op2(x, y, z) + z
f = th.function([x, y, z], [e3])
......
@@ -52,7 +52,7 @@ list of ops that support R-op:
* Reshape
* Flatten
* DimShuffle
-* Scan [In scan_module/tests/test_scan.test_rop]
+* Scan [In scan/tests/test_scan.test_rop]
* without test
* Split
......
@@ -539,7 +539,7 @@ value ``max_value``.
.. testcode::

    def power_of_2(previous_power, max_value):
-        return previous_power*2, theano.scan_module.until(previous_power*2 > max_value)
+        return previous_power*2, theano.scan.utils.until(previous_power*2 > max_value)

    max_value = tt.scalar()
    values, _ = theano.scan(power_of_2,
@@ -557,7 +557,7 @@ value ``max_value``.
As you can see, in order to terminate on condition, the only thing required
is that the inner function ``power_of_2`` also return the condition
-wrapped in the class ``theano.scan_module.until``. The condition has to be
+wrapped in the class ``theano.scan.utils.until``. The condition has to be
expressed in terms of the arguments of the inner function (in this case
``previous_power`` and ``max_value``).
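The `until` semantics in the hunk above can be mimicked in an eager-mode sketch (hypothetical code, not Theano's API; `scan_until` is an invented helper): the step function returns a pair `(next_value, stop_condition)`, and iteration stops after the step in which the condition becomes true.

```python
# Hypothetical eager analogue of scan's `until` termination.
def scan_until(step, initial, n_steps):
    outputs = []
    value = initial
    for _ in range(n_steps):
        value, stop = step(value)
        outputs.append(value)
        if stop:           # condition returned by the step function
            break
    return outputs

def power_of_2(previous_power, max_value=100):
    # The condition is expressed in terms of the step's arguments,
    # mirroring the documented requirement above.
    return previous_power * 2, previous_power * 2 > max_value

print(scan_until(power_of_2, 1, 1024))  # [2, 4, 8, 16, 32, 64, 128]
```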
@@ -675,7 +675,7 @@ higher memory usage.
reference
=========

-.. automodule:: theano.scan_module
+.. automodule:: theano.scan

.. autofunction:: theano.map
.. autofunction:: theano.reduce
......
@@ -32,9 +32,9 @@ class OfgNested:
    def __init__(self):
        x, y, z = tt.scalars("xyz")
        e = x * y
-        op = theano.OpFromGraph([x, y], [e])
+        op = theano.compile.builders.OpFromGraph([x, y], [e])
        e2 = op(x, y) + z
-        op2 = theano.OpFromGraph([x, y, z], [e2])
+        op2 = theano.compile.builders.OpFromGraph([x, y, z], [e2])
        e3 = op2(x, y, z) + z
        self.inputs = [x, y, z]
@@ -45,7 +45,7 @@ class Ofg:
    def __init__(self):
        x, y, z = tt.scalars("xyz")
        e = tt.nnet.sigmoid((x + y + z) ** 2)
-        op = theano.OpFromGraph([x, y, z], [e])
+        op = theano.compile.builders.OpFromGraph([x, y, z], [e])
        e2 = op(x, y, z) + op(z, y, x)
        self.inputs = [x, y, z]
@@ -56,7 +56,7 @@ class OfgSimple:
    def __init__(self):
        x, y, z = tt.scalars("xyz")
        e = tt.nnet.sigmoid((x + y + z) ** 2)
-        op = theano.OpFromGraph([x, y, z], [e])
+        op = theano.compile.builders.OpFromGraph([x, y, z], [e])
        e2 = op(x, y, z)
        self.inputs = [x, y, z]
......
@@ -2,29 +2,22 @@ import numpy as np
import pytest

import theano
-import theano.gpuarray
import theano.tensor as tt
+from theano.scan.basic import scan
+from theano.scan.checkpoints import scan_checkpoints

-try:
-    from pygpu.gpuarray import GpuArrayException
-
-    PYGPU_AVAILABLE = True
-except ImportError:
-    PYGPU_AVAILABLE = False


class TestScanCheckpoint:
    def setup_method(self):
        self.k = tt.iscalar("k")
        self.A = tt.vector("A")
-        result, _ = theano.scan(
+        result, _ = scan(
            fn=lambda prior_result, A: prior_result * A,
            outputs_info=tt.ones_like(self.A),
            non_sequences=self.A,
            n_steps=self.k,
        )
-        result_check, _ = theano.scan_checkpoints(
+        result_check, _ = scan_checkpoints(
            fn=lambda prior_result, A: prior_result * A,
            outputs_info=tt.ones_like(self.A),
            non_sequences=self.A,
@@ -52,34 +45,7 @@ class TestScanCheckpoint:
        out, out_check = f(range(10), 101)
        assert np.allclose(out, out_check)

-    @pytest.mark.skipif(~PYGPU_AVAILABLE, reason="Requires pygpu.")
-    @pytest.mark.skipif(
-        None not in theano.gpuarray.type.list_contexts(),
-        reason="Requires gpuarray backend.",
-    )
-    def test_memory(self):
-        from tests.gpuarray.config import mode_with_gpu  # noqa
-
-        f = theano.function(
-            inputs=[self.A, self.k], outputs=self.grad_A, mode=mode_with_gpu
-        )
-        f_check = theano.function(
-            inputs=[self.A, self.k], outputs=self.grad_A_check, mode=mode_with_gpu
-        )
-        free_gmem = theano.gpuarray.type._context_reg[None].free_gmem
-        data = np.ones(free_gmem // 3000, dtype=np.float32)
-        # Check that it works with the checkpoints
-        size = 1000
-        if isinstance(mode_with_gpu, theano.compile.DebugMode):
-            size = 100
-        f_check(data, size)
-        # Check that the basic scan fails in that case
-        # Skip that check in DebugMode, as it can fail in different ways
-        if not isinstance(mode_with_gpu, theano.compile.DebugMode):
-            with pytest.raises(GpuArrayException):
-                f(data, 1000)

    def test_taps_error(self):
        # Test that an error is raised if we use taps in outputs_info.
        with pytest.raises(RuntimeError):
-            theano.scan_checkpoints(lambda: None, [], {"initial": self.A, "taps": [-2]})
+            scan_checkpoints(lambda: None, [], {"initial": self.A, "taps": [-2]})
@@ -4,7 +4,7 @@ import theano
import theano.tensor as tt
from tests import unittest_tools as utt
from theano import config
-from theano.scan_module.scan_op import Scan
+from theano.scan.op import Scan

mode = theano.compile.mode.get_mode(config.mode)
......
@@ -5,7 +5,7 @@ import pytest
import theano
from theano import tensor
-from theano.scan_module.scan_utils import map_variables
+from theano.scan.utils import map_variables


class TestMapVariables:
@@ -130,7 +130,7 @@ class TestMapVariables:
        # construct the outer graph
        c = tensor.scalar()
        d = tensor.scalar()
-        u = theano.OpFromGraph([a, b], [r])(c, d)
+        u = theano.compile.builders.OpFromGraph([a, b], [r])(c, d)
        t = z * u
        (v,) = map_variables(self.replacer, [t])
        t2 = z * v
......
@@ -79,7 +79,6 @@ from theano import scalar, tensor
from theano.compile import (
    In,
    Mode,
-    OpFromGraph,
    Out,
    Param,
    ProfileStats,
@@ -194,4 +193,4 @@ def sparse_grad(var):
import theano.tensor.shared_randomstreams
-from theano.scan_module import clone, foldl, foldr, map, reduce, scan, scan_checkpoints
-from theano.compile.builders import OpFromGraph, ops_with_inner_function
+from theano.scan import checkpoints, clone, foldl, foldr, map, reduce, scan
from theano.compile.debugmode import DebugMode
from theano.compile.function.pfunc import Param, pfunc, rebuild_collect_shared
from theano.compile.function.types import (
......
@@ -10,12 +10,69 @@ from theano.compile.function.types import orig_function
from theano.compile.mode import optdb
from theano.compile.sharedvalue import SharedVariable
from theano.gof import Variable, ops_with_inner_function
+from theano.gof.fg import FunctionGraph
from theano.gof.graph import io_connection_pattern
from theano.gof.null_type import NullType
+from theano.gof.op import Op
+from theano.gof.opt import in2out, local_optimizer
from theano.gradient import DisconnectedType
+from theano.tensor.opt import ShapeFeature


-class OpFromGraph(gof.Op):
+def infer_shape(outs, inputs, input_shapes):
+    """
+    Compute the shape of the outputs given the shape of the inputs of a theano
+    graph.
+
+    We do it this way to avoid compiling the inner function just to get
+    the shape. Changes to ShapeFeature could require changes in this function.
+    """
+    # We use a ShapeFeature because it has all the necessary logic
+    # inside.  We don't use the full ShapeFeature interface, but we
+    # let it initialize itself with an empty fgraph, otherwise we will
+    # need to do it manually
+    for inp, inp_shp in zip(inputs, input_shapes):
+        if inp_shp is not None and len(inp_shp) != inp.ndim:
+            assert len(inp_shp) == inp.ndim
+
+    shape_feature = ShapeFeature()
+    shape_feature.on_attach(FunctionGraph([], []))
+
+    # Initialize shape_of with the input shapes
+    for inp, inp_shp in zip(inputs, input_shapes):
+        shape_feature.set_shape(inp, inp_shp)
+
+    def local_traverse(out):
+        """
+        Go back in the graph, from out, adding computable shapes to shape_of.
+        """
+        if out in shape_feature.shape_of:
+            # Its shape is already known
+            return
+        elif out.owner is None:
+            # This is an input of the graph
+            shape_feature.init_r(out)
+        else:
+            # Recurse over inputs
+            for inp in out.owner.inputs:
+                if inp not in shape_feature.shape_of:
+                    local_traverse(inp)
+            # shape_feature.on_import does not actually use an fgraph
+            # It will call infer_shape and set_shape appropriately
+            dummy_fgraph = None
+            shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
+
+    ret = []
+    for o in outs:
+        local_traverse(o)
+        ret.append(shape_feature.shape_of[o])
+    return ret
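The traversal idea in the `infer_shape` helper above can be shown on a toy graph (a self-contained sketch, not Theano's API; `Node` and `infer_shapes` are invented stand-ins for Theano's graph and `ShapeFeature` machinery): walk back from each output, compute a shape for every node from the already-known shapes of its inputs, and cache the results.

```python
# Toy graph node: `op` is None for graph inputs.
class Node:
    def __init__(self, op=None, inputs=()):
        self.op = op
        self.inputs = list(inputs)

def infer_shapes(outputs, input_shapes):
    shape_of = dict(input_shapes)  # maps input Node -> shape tuple

    def local_traverse(node):
        if node in shape_of:
            return shape_of[node]               # shape already known
        shapes = [local_traverse(i) for i in node.inputs]
        if node.op == "add":                    # elementwise: same shape
            shape = shapes[0]
        elif node.op == "matmul":               # (m, k) @ (k, n) -> (m, n)
            shape = (shapes[0][0], shapes[1][1])
        else:
            raise ValueError(f"unknown op {node.op!r}")
        shape_of[node] = shape                  # cache, like shape_of
        return shape

    return [local_traverse(o) for o in outputs]

a, b = Node(), Node()
c = Node("matmul", [a, b])
d = Node("add", [c, c])
print(infer_shapes([d], {a: (2, 3), b: (3, 4)}))  # [(2, 4)]
```

As in the real helper, no function is compiled: shapes are propagated purely symbolically from the inputs.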
+class OpFromGraph(Op):
    r"""
    This creates an ``Op`` from inputs and outputs lists of variables.
    The signature is similar to :func:`theano.function <theano.function>`
@@ -28,9 +85,9 @@ class OpFromGraph(gof.Op):
    Parameters
    ----------
-    inputs: list of :class:`Variable <theano.gof.Variable>`
-    outputs: list of :class:`Variable <theano.gof.Variable>`
+    inputs: list of :class:`Variable <theano.gof.graph.Variable>`
+    outputs: list of :class:`Variable <theano.gof.graph.Variable>`
    inline: bool, optional
        Defaults to ``False``
@@ -52,15 +109,15 @@ class OpFromGraph(gof.Op):
            arguments as one would specify in grad() method.
        callable : Should take two args: ``inputs`` and ``output_grads``.
-            Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
-            Must return list of :class:`Variable <theano.gof.Variable>`.
+            Each argument is expected to be a list of :class:`Variable <theano.gof.graph.Variable>`.
+            Must return list of :class:`Variable <theano.gof.graph.Variable>`.
        Variable :
            ``NullType() instance`` : Treat as non-differentiable
            ``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
        list: Each OpFromGraph/callable must return a single
-            :class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
+            :class:`Variable <theano.gof.graph.Variable>`. Each list element corresponds to gradient of
            a specific input, length of list must be equal to number of inputs.
    lop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
@@ -74,15 +131,15 @@ class OpFromGraph(gof.Op):
            arguments as one would specify in grad() method.
        callable : Should take three args: ``inputs``, ``outputs`` and ``output_grads``.
-            Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
-            Must return list of :class:`Variable <theano.gof.Variable>`.
+            Each argument is expected to be a list of :class:`Variable <theano.gof.graph.Variable>`.
+            Must return list of :class:`Variable <theano.gof.graph.Variable>`.
        Variable :
            ``NullType() instance`` : Treat as non-differentiable
            ``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
        list: Each OpFromGraph/callable must return a single
-            :class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
+            :class:`Variable <theano.gof.graph.Variable>`. Each list element corresponds to gradient of
            a specific input, length of list must be equal to number of inputs.
    rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
@@ -95,15 +152,15 @@ class OpFromGraph(gof.Op):
            arguments as one would specify in R_op() method.
        callable : Should take two args: ``inputs`` and ``eval_points``.
-            Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
-            Must return list of :class:`Variable <theano.gof.Variable>`.
+            Each argument is expected to be a list of :class:`Variable <theano.gof.graph.Variable>`.
+            Must return list of :class:`Variable <theano.gof.graph.Variable>`.
        Variable :
            ``NullType() instance`` : Treat as non-differentiable
            ``DisconnectedType() instance`` : Treat as zero since DisconnectedType is not yet supported in R_op
        list: Each OpFromGraph/callable must return a single
-            :class:`Variable <theano.gof.Variable>`. Each list element corresponds
+            :class:`Variable <theano.gof.graph.Variable>`. Each list element corresponds
            to a specific output of R_op, length of list must be equal to number of outputs.
    connection_pattern : list of list
@@ -158,7 +215,8 @@ class OpFromGraph(gof.Op):
    .. code-block:: python

-        from theano import function, OpFromGraph, tensor
+        from theano import function, tensor
+        from theano.compile.builders import OpFromGraph
        x, y, z = tensor.scalars('xyz')
        e = x + y * z
        op = OpFromGraph([x, y, z], [e])
@@ -172,7 +230,9 @@ class OpFromGraph(gof.Op):
        import numpy as np
        import theano
-        from theano import config, function, OpFromGraph, tensor
+        from theano import config, function, tensor
+        from theano.compile.builders import OpFromGraph
+
        x, y, z = tensor.scalars('xyz')
        s = theano.shared(np.random.rand(2, 2).astype(config.floatX))
        e = x + y * z + s
@@ -185,7 +245,9 @@ class OpFromGraph(gof.Op):
    .. code-block:: python

-        from theano import function, OpFromGraph, tensor, grad
+        from theano import function, tensor, grad
+        from theano.compile.builders import OpFromGraph
+
        x, y, z = tensor.scalars('xyz')
        e = x + y * z
        def rescale_dy(inps, grads):
@@ -718,9 +780,8 @@ class OpFromGraph(gof.Op):
        return list(map(list, cpmat_self))

    def infer_shape(self, node, shapes):
-        out_shp = theano.scan_module.scan_utils.infer_shape(
-            self.local_outputs, self.local_inputs, shapes
-        )
+        out_shp = infer_shape(self.local_outputs, self.local_inputs, shapes)

        # Clone the output shape so that shapes are computed from outer inputs.
        # Note:
@@ -756,7 +817,7 @@ class OpFromGraph(gof.Op):
        output[0] = variable.copy()


-@gof.local_optimizer([OpFromGraph])
+@local_optimizer([OpFromGraph])
def inline_ofg_expansion(node):
    """
    This optimization expands the internal graph of OpFromGraph.
@@ -777,7 +838,7 @@ def inline_ofg_expansion(node):
# and before the first scan optimizer.
optdb.register(
    "inline_ofg_expansion",
-    gof.opt.in2out(inline_ofg_expansion),
+    in2out(inline_ofg_expansion),
    -0.01,
    "fast_compile",
    "fast_run",
......
@@ -682,7 +682,9 @@ def debugprint(
        new_prefix_child = prefix_child + "   "

        if hasattr(i, "owner") and hasattr(i.owner, "op"):
-            if isinstance(i.owner.op, theano.scan_module.scan_op.Scan):
+            from theano.scan.op import Scan
+
+            if isinstance(i.owner.op, Scan):
                scan_ops.append(i)
                debugprint(
......
@@ -1386,7 +1386,7 @@ class FunctionMaker:
        try:
            with open(graph_db_file, "rb") as f:
                # Temporary hack to allow
-                # tests.scan_module.test_scan.T_Scan to
+                # tests.scan.test_scan.T_Scan to
                # finish. Should be changed in definitive version.
                tmp = theano.config.unpickle_function
                theano.config.unpickle_function = False
......
@@ -8,8 +8,6 @@ import logging
from theano import gof

-from .sharedvalue import SharedVariable

_logger = logging.getLogger("theano.compile.io")
@@ -211,6 +209,8 @@ class In(SymbolicInput):
            )
        if implicit is None:
+            from theano.compile.sharedvalue import SharedVariable
+
            implicit = isinstance(value, gof.Container) or isinstance(
                value, SharedVariable
            )
......
@@ -2,17 +2,16 @@
Provide a simple user friendly API to Theano-managed memory.
"""
-# Standard imports
import copy
import logging

-# Third-party imports
import numpy as np

-# Theano imports
-from theano.gof import Container, Variable, generic, utils
+from theano.gof.graph import Variable
+from theano.gof.link import Container
+from theano.gof.type import generic
+from theano.gof.utils import add_tag_trace

_logger = logging.getLogger("theano.compile.sharedvalue")
@@ -287,7 +286,7 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
                allow_downcast=allow_downcast,
                **kwargs,
            )
-            utils.add_tag_trace(var)
+            add_tag_trace(var)
            return var
        except TypeError:
            continue
......
@@ -533,11 +533,9 @@ class ReplaceValidate(History, Validator):
        if verbose is None:
            verbose = config.optimizer_verbose
        if config.scan.debug:
-            scans = [
-                n
-                for n in fgraph.apply_nodes
-                if isinstance(n.op, theano.scan_module.scan_op.Scan)
-            ]
+            from theano.scan.op import Scan
+
+            scans = [n for n in fgraph.apply_nodes if isinstance(n.op, Scan)]

        for r, new_r in replacements:
            try:
@@ -581,11 +579,9 @@ class ReplaceValidate(History, Validator):
            )
            raise
        if config.scan.debug:
-            scans2 = [
-                n
-                for n in fgraph.apply_nodes
-                if isinstance(n.op, theano.scan_module.scan_op.Scan)
-            ]
+            from theano.scan.op import Scan
+
+            scans2 = [n for n in fgraph.apply_nodes if isinstance(n.op, Scan)]
            nb = len(scans)
            nb2 = len(scans2)
            if nb2 > nb:
......
@@ -103,8 +103,8 @@ def add_tag_trace(thing, user_line=None):
        "theano\\scalar\\basic.py",
        "theano/sandbox/",
        "theano\\sandbox\\",
-        "theano/scan_module/",
-        "theano\\scan_module\\",
+        "theano/scan/",
+        "theano\\scan\\",
        "theano/sparse/",
        "theano\\sparse\\",
        "theano/typed_list/",
......
@@ -153,7 +153,9 @@ from theano.ifelse import IfElse
from theano.misc.ordered_set import OrderedSet
from theano.scalar.basic import Cast, Pow, Scalar, log, neg, true_div
from theano.scalar.basic_scipy import Erfcinv, Erfinv
-from theano.scan_module import scan_op, scan_opt, scan_utils
+from theano.scan import utils
+from theano.scan.op import Scan
+from theano.scan.opt import ScanInplaceOptimizer
from theano.tensor.nnet import bn, conv3d2d
from theano.tensor.nnet.abstract_conv import (
    AbstractConv2d,
@@ -2600,13 +2602,13 @@ def gpu_reconstruct_graph(inputs, outputs, tag=None):
    givens = {}
    for nw_x, x in zip(nw_inputs, inputs):
        givens[x] = nw_x
-    nw_outputs = scan_utils.clone(outputs, replace=givens)
+    nw_outputs = utils.clone(outputs, replace=givens)
    return (nw_inputs, nw_outputs)


@register_opt("scan", "fast_compile")
-@op_lifter([scan_op.Scan])
-@register_opt2([scan_op.Scan], "fast_compile")
+@op_lifter([Scan])
+@register_opt2([Scan], "fast_compile")
def local_gpua_scan_to_gpua(op, context_name, inputs, outputs):
    info = copy.deepcopy(op.info)
    if info.get("gpua", False):
@@ -2628,7 +2630,7 @@ def local_gpua_scan_to_gpua(op, context_name, inputs, outputs):
        scan_outs += [op.outputs[-1]]
    else:
        scan_outs = [safe_to_gpu(x, context_name) for x in op.outputs]
-    scan_outs = scan_utils.clone(
+    scan_outs = utils.clone(
        scan_outs, replace=list(zip(op.inputs, (safe_to_cpu(x) for x in scan_ins)))
    )
...@@ -2645,9 +2647,9 @@ def local_gpua_scan_to_gpua(op, context_name, inputs, outputs): ...@@ -2645,9 +2647,9 @@ def local_gpua_scan_to_gpua(op, context_name, inputs, outputs):
dtype=dtype, broadcastable=broadcastable, context_name=context_name dtype=dtype, broadcastable=broadcastable, context_name=context_name
) )
nw_op = scan_op.Scan( nw_op = Scan(scan_ins, scan_outs, info, typeConstructor=typebuild).make_node(
scan_ins, scan_outs, info, typeConstructor=typebuild *nw_ins
).make_node(*nw_ins) )
return nw_op.outputs return nw_op.outputs
...@@ -2916,7 +2918,7 @@ def local_gpu_ctc(op, context_name, inputs, outputs): ...@@ -2916,7 +2918,7 @@ def local_gpu_ctc(op, context_name, inputs, outputs):
# It will be added to fast_run if the GPU is enabled. # It will be added to fast_run if the GPU is enabled.
optdb.register( optdb.register(
"gpua_scanOp_make_inplace", "gpua_scanOp_make_inplace",
scan_opt.ScanInplaceOptimizer(typeInfer=_scan_type_infer, gpua_flag=True), ScanInplaceOptimizer(typeInfer=_scan_type_infer, gpua_flag=True),
75, 75,
"gpuarray", "gpuarray",
"inplace", "inplace",
......
@@ -1777,8 +1777,8 @@ def verify_grad(
     Notes
     -----
     This function does not support multiple outputs. In
-    tests/test_scan.py there is an experimental verify_grad that
-    covers that case as well by using random projections.
+    tests/scan/test_basic.py there is an experimental `verify_grad` that covers
+    that case as well by using random projections.
...
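The note above mentions handling multiple outputs "by using random projections". The idea, restated from first principles (this is a sketch, not Theano's actual test helper), is to reduce a multi-output function to a scalar by weighting each output with fixed random coefficients, then compare the analytic gradient of that scalar against central finite differences:

```python
import random

def numeric_grad(f, x, eps=1e-6):
    """Central finite differences of a scalar function f at point x (a list)."""
    g = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        g.append((f(xp) - f(xm)) / (2 * eps))
    return g

def verify_grad_multi(f_outputs, grad_outputs, x, tol=1e-4):
    """f_outputs(x) -> list of outputs; grad_outputs(x, w) -> analytic gradient
    of sum_j w[j] * output_j.  Random weights w project the multiple outputs
    down to a single scalar that numeric_grad can check."""
    rng = random.Random(0)                       # fixed seed: reproducible check
    w = [rng.uniform(0.5, 1.5) for _ in f_outputs(x)]
    scalar = lambda pt: sum(wj * oj for wj, oj in zip(w, f_outputs(pt)))
    num = numeric_grad(scalar, x)
    ana = grad_outputs(x, w)
    return all(abs(a - n) <= tol * max(1.0, abs(a)) for a, n in zip(ana, num))

# outputs: [x0*x1, x0 + x1]; analytic grad of w0*x0*x1 + w1*(x0 + x1)
ok = verify_grad_multi(
    lambda x: [x[0] * x[1], x[0] + x[1]],
    lambda x, w: [w[0] * x[1] + w[1], w[0] * x[0] + w[1]],
    [2.0, 3.0],
)
```

A single random projection can in principle miss an error that happens to lie in its null space, which is why test code typically draws the weights away from zero, as above.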
@@ -20,7 +20,7 @@ import theano.tensor
 from theano import gof
 from theano.compile import optdb
 from theano.gof import Apply, Op
-from theano.scan_module.scan_utils import clone
+from theano.scan.utils import clone
 from theano.tensor import TensorType, opt
...
@@ -120,6 +120,8 @@ def debugprint(
     to the Apply's identifier, to indicate which output a line corresponds to.
     """
+    from theano.scan.op import Scan
+
     if not isinstance(depth, int):
         raise Exception("depth parameter must be an int")
     if file == "str":
@@ -202,9 +204,7 @@ N.B.:
     for r, p, s, o in zip(results_to_print, profile_list, smap, order):
         # Add the parent scan op to the list as well
-        if hasattr(r.owner, "op") and isinstance(
-            r.owner.op, theano.scan_module.scan_op.Scan
-        ):
+        if hasattr(r.owner, "op") and isinstance(r.owner.op, Scan):
             scan_ops.append(r)
         debugmode.debugprint(
@@ -265,7 +265,7 @@ N.B.:
     for idx, i in enumerate(outputs):
         if hasattr(i, "owner") and hasattr(i.owner, "op"):
-            if isinstance(i.owner.op, theano.scan_module.scan_op.Scan):
+            if isinstance(i.owner.op, Scan):
                 scan_ops.append(i)
                 debugmode.debugprint(
@@ -804,6 +804,8 @@ def pydotprint(
     scan separately after the top level debugprint output.
     """
+    from theano.scan.op import Scan
+
     if colorCodes is None:
         colorCodes = default_colorCodes
@@ -1119,11 +1121,7 @@ def pydotprint(
     outfile += "." + format
     if scan_graphs:
-        scan_ops = [
-            (idx, x)
-            for idx, x in enumerate(topo)
-            if isinstance(x.op, theano.scan_module.scan_op.Scan)
-        ]
+        scan_ops = [(idx, x) for idx, x in enumerate(topo) if isinstance(x.op, Scan)]
         path, fn = os.path.split(outfile)
         basename = ".".join(fn.split(".")[:-1])
         # Safe way of doing things .. a file name may contain multiple .
...
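The comprehension that `pydotprint` ends up with collects every Scan apply from the topological order in one pass. The filtering itself is straightforward; a self-contained sketch with stub classes (stand-ins for real ops, names hypothetical) behaves the same way:

```python
from collections import namedtuple

Apply = namedtuple("Apply", ["op"])

class ScanStub:        # stand-in for theano.scan.op.Scan
    pass

class ElemwiseStub:    # stand-in for any other op
    pass

topo = [Apply(ElemwiseStub()), Apply(ScanStub()),
        Apply(ElemwiseStub()), Apply(ScanStub())]

# mirrors: [(idx, x) for idx, x in enumerate(topo) if isinstance(x.op, Scan)]
scan_ops = [(idx, node) for idx, node in enumerate(topo)
            if isinstance(node.op, ScanStub)]
# each collected (index, apply) pair is later printed as its own sub-graph
```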
@@ -20,8 +20,8 @@ from theano.compile.ops import (
 from theano.gof import FunctionGraph
 from theano.ifelse import IfElse
 from theano.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp
-from theano.scan_module.scan_op import Scan
-from theano.scan_module.scan_utils import scan_args as ScanArgs
+from theano.scan.op import Scan
+from theano.scan.utils import scan_args as ScanArgs
 from theano.tensor.basic import (
     Alloc,
     AllocEmpty,
...
@@ -5,7 +5,7 @@ Scanning is a general form of recurrence, which can be used for looping.
 The idea is that you *scan* a function along some input sequence, producing
 an output at each time-step that can be seen (but not modified) by the
 function at the next time-step. (Technically, the function can see the
-previous K time-steps of your outputs and L time steps (from the past and
+previous K time-steps of your outputs and L time steps (from past and
 future) of your inputs.)
@@ -13,15 +13,21 @@ function over a list, given an initial state of ``z=0``.
 Special cases:
-* A *reduce* operation can be performed by returning only the last
+* A *reduce* operation can be performed by using only the last
   output of a ``scan``.
 * A *map* operation can be performed by applying a function that
   ignores previous steps of the outputs.
-Often a for-loop can be expressed as a ``scan()`` operation, and ``scan`` is
-the closest that theano comes to looping. The advantage of using ``scan``
-over for loops is that it allows the number of iterations to be a part of
-the symbolic graph.
+Often a for-loop or while-loop can be expressed as a ``scan()`` operation,
+and ``scan`` is the closest that theano comes to looping. The advantages
+of using ``scan`` over ``for`` loops in Python (among others) are:
+
+* it allows the number of iterations to be part of the symbolic graph
+* it allows computing gradients through the for loop
+* a number of optimizations can rewrite the loop so that it uses less
+  memory and runs faster
+* it ensures that data is not copied from host to gpu and gpu to
+  host at each step
 The Scan Op should typically be used by calling any of the following
 functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
@@ -41,8 +47,8 @@ __authors__ = (
 __copyright__ = "(c) 2010, Universite de Montreal"
 __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
-from theano.scan_module import scan_opt
-from theano.scan_module.scan import scan
-from theano.scan_module.scan_checkpoints import scan_checkpoints
-from theano.scan_module.scan_utils import clone, until
-from theano.scan_module.scan_views import foldl, foldr, map, reduce
+from theano.scan import opt
+from theano.scan.basic import scan
+from theano.scan.checkpoints import scan_checkpoints
+from theano.scan.utils import clone, until
+from theano.scan.views import foldl, foldr, map, reduce
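The semantics this docstring describes can be sketched in a few lines of plain Python. The toy `scan_sketch` below (an illustration only; the real `scan` builds a symbolic `Scan` Op rather than looping eagerly) threads a state through a sequence, and the map/reduce special cases fall out directly:

```python
def scan_sketch(fn, sequence, init):
    """Apply fn(x_t, prev_state) along `sequence`, collecting every output."""
    outputs, state = [], init
    for x in sequence:
        state = fn(x, state)
        outputs.append(state)
    return outputs

# sum(): scan the z + x_i function over a list with initial state z = 0
running = scan_sketch(lambda x, z: z + x, [1, 2, 3, 4], 0)   # running sums
total = running[-1]    # reduce: keep only the last output of the scan
# map: the step function simply ignores the previous state
squares = scan_sketch(lambda x, _: x * x, [1, 2, 3], 0)
```

In the symbolic version the same loop body becomes an inner graph, which is what makes the iteration count symbolic and the whole loop differentiable.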
"""
This module provides the Scan Op.
Scanning is a general form of recurrence, which can be used for looping.
The idea is that you *scan* a function along some input sequence, producing
an output at each time-step that can be seen (but not modified) by the
function at the next time-step. (Technically, the function can see the
previous K time-steps of your outputs and L time steps (from past and
future) of your inputs.
So for example, ``sum()`` could be computed by scanning the ``z+x_i``
function over a list, given an initial state of ``z=0``.
Special cases:
* A *reduce* operation can be performed by using only the last
output of a ``scan``.
* A *map* operation can be performed by applying a function that
ignores previous steps of the outputs.
Often a for-loop or while-loop can be expressed as a ``scan()`` operation,
and ``scan`` is the closest that theano comes to looping. The advantages
of using ``scan`` over `for` loops in python (amongs other) are:
* it allows the number of iterations to be part of the symbolic graph
* it allows computing gradients through the for loop
* there exist a bunch of optimizations that help re-write your loop
such that less memory is used and that it runs faster
* it ensures that data is not copied from host to gpu and gpu to
host at each step
The Scan Op should typically be used by calling any of the following
functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
``foldr()``.
"""
 __docformat__ = "restructedtext en"
 __authors__ = "Razvan Pascanu " "Frederic Bastien " "James Bergstra " "Pascal Lamblin "
 __copyright__ = "(c) 2010, Universite de Montreal"
@@ -53,13 +15,14 @@ from theano.compile import SharedVariable, ops
 from theano.compile.function import function
 from theano.compile.mode import Mode
 from theano.gof.utils import TestValueError
-from theano.scan_module import scan_op, scan_utils
-from theano.scan_module.scan_utils import safe_new, traverse
+from theano.scan import utils
+from theano.scan.op import Scan
+from theano.scan.utils import safe_new, traverse
 from theano.tensor import opt
 from theano.updates import OrderedUpdates
-_logger = logging.getLogger("theano.scan_module.scan")
+_logger = logging.getLogger("theano.scan.basic")
 def scan(
@@ -167,7 +130,7 @@ def scan(
        .. code-block:: python
            ...
-           return [y1_t, y2_t], {x: x + 1}, theano.scan_module.until(x < 50)
+           return [y1_t, y2_t], {x: x + 1}, until(x < 50)
     Note that a number of steps (considered in here as the maximum
     number of steps ) is still required even though a condition is
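The `until(x < 50)` in the docstring example turns scan into a while-loop: the step function returns a stopping condition alongside its outputs and updates, and `n_steps` still caps the iteration count. The control flow can be sketched eagerly (illustrative only; the real mechanism is symbolic):

```python
def scan_until(step, n_steps, init):
    """Run step(state) -> (new_state, done) at most n_steps times,
    stopping early once `done` is truthy -- a sketch of scan + until."""
    state, trace = init, []
    for _ in range(n_steps):
        state, done = step(state)
        trace.append(state)
        if done:
            break
    return trace

# double x until it exceeds 50, but never run more than 100 steps
trace = scan_until(lambda x: (2 * x, 2 * x > 50), n_steps=100, init=1)
# trace == [2, 4, 8, 16, 32, 64]
```

The `n_steps` bound is why a maximum step count is required even when a condition is passed: without it, the size of the pre-allocated output buffers would be unknown.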
@@ -576,7 +539,7 @@ def scan(
     for seq in scan_seqs:
         lengths_vec.append(seq.shape[0])
-    if not scan_utils.isNaN_or_Inf_or_None(n_steps):
+    if not utils.isNaN_or_Inf_or_None(n_steps):
         # ^ N_steps should also be considered
         lengths_vec.append(tt.as_tensor(n_steps))
@@ -591,7 +554,7 @@ def scan(
     # If the user has provided the number of steps, do that regardless ( and
     # raise an error if the sequences are not long enough )
-    if scan_utils.isNaN_or_Inf_or_None(n_steps):
+    if utils.isNaN_or_Inf_or_None(n_steps):
         actual_n_steps = lengths_vec[0]
         for contestant in lengths_vec[1:]:
             actual_n_steps = tt.minimum(actual_n_steps, contestant)
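`utils.isNaN_or_Inf_or_None` decides whether the caller actually pinned the number of steps: if not, scan runs for the length of the shortest sequence. The selection logic, restated eagerly in plain Python (function name hypothetical):

```python
import math

def pick_n_steps(lengths, n_steps=None):
    """An explicit, finite n_steps wins; otherwise iterate for the
    shortest sequence (mirrors the minimum-taking loop above)."""
    unspecified = n_steps is None or (
        isinstance(n_steps, float) and (math.isnan(n_steps) or math.isinf(n_steps))
    )
    if not unspecified:
        return n_steps
    steps = lengths[0]
    for contestant in lengths[1:]:
        steps = min(steps, contestant)
    return steps
```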
@@ -671,7 +634,7 @@ def scan(
             # the initial state over. We do this using the expand function
             # defined in scan utils
             sit_sot_scan_inputs.append(
-                scan_utils.expand_empty(
+                utils.expand_empty(
                     tt.unbroadcast(tt.shape_padleft(actual_arg), 0),
                     actual_n_steps,
                 )
@@ -695,7 +658,7 @@ def scan(
             mit_sot_tap_array.append(init_out["taps"])
             # Sequence
             mit_sot_scan_inputs.append(
-                scan_utils.expand_empty(init_out["initial"][:mintap], actual_n_steps)
+                utils.expand_empty(init_out["initial"][:mintap], actual_n_steps)
             )
             if i in return_steps:
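`utils.expand_empty(initial, n_steps)` pre-allocates the output storage: the initial taps sit at the front, followed by `n_steps` uninitialized slots that the iterations then fill in place. A list-based sketch of that layout (assumed behaviour for illustration, using `None` where the real code leaves memory uninitialized):

```python
def expand_empty_sketch(initial_taps, n_steps):
    """Buffer with the initial taps up front and room for n_steps further
    outputs; None marks slots the scan iterations will overwrite."""
    return list(initial_taps) + [None] * n_steps

buf = expand_empty_sketch([[0.0, 0.0]], 3)   # one initial tap, three steps
# step t writes buf[t + 1], so buf[1], buf[2], buf[3] get filled in place
```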
@@ -783,7 +746,7 @@ def scan(
     # when we apply the lambda expression we get a mixture of update rules
     # and outputs that needs to be separated
-    condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
+    condition, outputs, updates = utils.get_updates_and_outputs(fn(*args))
     if condition is not None:
         as_while = True
     else:
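`utils.get_updates_and_outputs` untangles whatever the step function returned: a list of outputs, a dict of updates, an `until` condition, or some combination of these. A simplified dispatcher conveys the idea (the real function accepts more layouts and validates them; names here are toy stand-ins):

```python
class Until:
    """Toy stand-in for theano.scan.utils.until."""
    def __init__(self, condition):
        self.condition = condition

def split_step_return(ret):
    """Separate a step function's return value into
    (condition, outputs, updates)."""
    condition, outputs, updates = None, [], {}
    for item in ret if isinstance(ret, tuple) else (ret,):
        if isinstance(item, Until):
            condition = item.condition
        elif isinstance(item, dict):
            updates = item
        elif isinstance(item, list):
            outputs = item
        else:
            outputs = [item]
    return condition, outputs, updates

# the docstring's "return [y1_t, y2_t], {x: x+1}, until(x < 50)" shape
cond, outs, upds = split_step_return(
    (["y1_t", "y2_t"], {"x": "x+1"}, Until("x < 50"))
)
```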
@@ -838,7 +801,7 @@ def scan(
     if condition is not None:
         outputs.append(condition)
     fake_nonseqs = [x.type() for x in non_seqs]
-    fake_outputs = scan_utils.clone(
+    fake_outputs = utils.clone(
         outputs, replace=OrderedDict(zip(non_seqs, fake_nonseqs))
     )
     all_inputs = filter(
@@ -927,7 +890,7 @@ def scan(
             if isinstance(new_var.type, ops.expandable_types):
                 sit_sot_inner_inputs.append(new_var)
                 sit_sot_scan_inputs.append(
-                    scan_utils.expand_empty(
+                    utils.expand_empty(
                         tt.unbroadcast(tt.shape_padleft(input.variable), 0),
                         actual_n_steps,
                     )
@@ -1065,7 +1028,7 @@ def scan(
     else:
         new_givens = givens
-    new_outs = scan_utils.clone(inner_outs, replace=new_givens)
+    new_outs = utils.clone(inner_outs, replace=new_givens)
     ##
     # Step 7. Create the Scan Op
@@ -1095,7 +1058,7 @@ def scan(
     info["allow_gc"] = allow_gc
     info["strict"] = strict
-    local_op = scan_op.Scan(inner_inputs, new_outs, info)
+    local_op = Scan(inner_inputs, new_outs, info)
     ##
     # Step 8. Compute the outputs using the scan op
...