提交 28e06d1f — 作者: Frédéric Bastien

Merge pull request #4125 from abergeron/fix_scan_bug

Fix scan pushout
...@@ -873,7 +873,8 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None): ...@@ -873,7 +873,8 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
def general_toposort(r_out, deps, debug_print=False, def general_toposort(r_out, deps, debug_print=False,
compute_deps_cache=None, deps_cache=None): compute_deps_cache=None, deps_cache=None,
clients=None):
""" """
WRITEME WRITEME
...@@ -886,6 +887,9 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -886,6 +887,9 @@ def general_toposort(r_out, deps, debug_print=False,
deps, but that also cache its results in a dict passed as deps_cache. deps, but that also cache its results in a dict passed as deps_cache.
deps_cache : dict deps_cache : dict
Must be used with compute_deps_cache. Must be used with compute_deps_cache.
clients : dict
If a dict is passed it will be filled with a mapping of node
-> clients for each node in the subgraph.
Notes Notes
----- -----
...@@ -924,8 +928,10 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -924,8 +928,10 @@ def general_toposort(r_out, deps, debug_print=False,
assert isinstance(r_out, (tuple, list, deque)) assert isinstance(r_out, (tuple, list, deque))
reachable, clients = stack_search(deque(r_out), compute_deps_cache, reachable, _clients = stack_search(deque(r_out), compute_deps_cache,
'dfs', True) 'dfs', True)
if clients is not None:
clients.update(_clients)
sources = deque([r for r in reachable if not deps_cache.get(r, None)]) sources = deque([r for r in reachable if not deps_cache.get(r, None)])
rset = set() rset = set()
...@@ -935,7 +941,7 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -935,7 +941,7 @@ def general_toposort(r_out, deps, debug_print=False,
if node not in rset: if node not in rset:
rlist.append(node) rlist.append(node)
rset.add(node) rset.add(node)
for client in clients.get(node, []): for client in _clients.get(node, []):
deps_cache[client] = [a for a in deps_cache[client] deps_cache[client] = [a for a in deps_cache[client]
if a is not node] if a is not node]
if not deps_cache[client]: if not deps_cache[client]:
...@@ -951,7 +957,7 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -951,7 +957,7 @@ def general_toposort(r_out, deps, debug_print=False,
return rlist return rlist
def io_toposort(inputs, outputs, orderings=None): def io_toposort(inputs, outputs, orderings=None, clients=None):
""" """
WRITEME WRITEME
...@@ -959,10 +965,13 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -959,10 +965,13 @@ def io_toposort(inputs, outputs, orderings=None):
---------- ----------
inputs : list or tuple of Variable instances inputs : list or tuple of Variable instances
outputs : list or tuple of Apply instances outputs : list or tuple of Apply instances
orderings: dict orderings : dict
Key: Apply instance. Value: list of Apply instance. Key: Apply instance. Value: list of Apply instance.
It is important that the value be a container with a deterministic It is important that the value be a container with a deterministic
iteration order. No sets allowed! iteration order. No sets allowed!
clients : dict
If a dict is provided it will be filled with mappings of
node->clients for each node in the subgraph that is sorted
""" """
# the inputs are used only here in the function that decides what 'predecessors' to explore # the inputs are used only here in the function that decides what 'predecessors' to explore
...@@ -1013,7 +1022,7 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -1013,7 +1022,7 @@ def io_toposort(inputs, outputs, orderings=None):
topo = general_toposort(outputs, deps=compute_deps, topo = general_toposort(outputs, deps=compute_deps,
compute_deps_cache=compute_deps_cache, compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache) deps_cache=deps_cache, clients=clients)
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
......
...@@ -56,7 +56,7 @@ from sys import maxsize ...@@ -56,7 +56,7 @@ from sys import maxsize
import numpy import numpy
import theano import theano
from theano import tensor from theano import tensor, scalar
from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
from theano import gof from theano import gof
from theano.compat import OrderedDict from theano.compat import OrderedDict
...@@ -694,98 +694,16 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -694,98 +694,16 @@ class PushOutScanOutput(gof.Optimizer):
op.inputs, op.outputs, op.info) op.inputs, op.outputs, op.info)
new_scan_node = None new_scan_node = None
local_fgraph_topo = theano.gof.graph.io_toposort(op.inputs, op.outputs) clients = {}
local_fgraph_topo = theano.gof.graph.io_toposort(args.inner_inputs,
args.inner_outputs,
clients=clients)
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Dot) and if (isinstance(nd.op, theano.tensor.elemwise.Elemwise) and
nd.out in args.inner_out_nit_sot): isinstance(nd.op.scalar_op, scalar.Add) and
""" nd.out in args.inner_out_sit_sot and
The following optimization involves pushing out, after the self.inner_sitsot_only_last_step_used(nd.out, args)):
scan, a Dot whose output is nitsot (not feed back to the inner
graph) and where one input is one of scan's input with ndim=2
and the other is an intermediate variable in the Scan inner
graph with ndim=1.
The Dot product is pushed out of the scan and its inputs are
now the original matrix and a new matrix obtained by
concatenating the vectors into a matrix.
"""
# Ensure that the output of the Dot is used in the outer
# graph to avoid apply the optimization needlessly
dot_out_nitsot_idx = args.inner_out_nit_sot.index(nd.out)
outer_dot_output = args.outer_out_nit_sot[dot_out_nitsot_idx]
if len(outer_dot_output.clients) == 0:
continue
"""
Validate that one of the inputs is a matrix AND a
non-sequence input to scan and that the other input is a
vector and either an sequence input to scan or the result
of computation in the inner function of scan.
"""
valid_inputs = False
idx_matrix_input = -1
idx_vector_input = -1
if (nd.inputs[0].ndim == 2 and
(nd.inputs[0] in args.inner_in_non_seqs or
isinstance(nd.inputs[0], tensor.Constant)) and
nd.inputs[1].ndim == 1 and
(nd.inputs[1] in args.inner_in_seqs or
nd.inputs[1] not in args.inner_inputs)):
valid_inputs = True
idx_matrix_input = 0
idx_vector_input = 1
elif (nd.inputs[1].ndim == 2 and
(nd.inputs[1] in args.inner_in_non_seqs or
isinstance(nd.inputs[1], tensor.Constant)) and
nd.inputs[0].ndim == 1 and
(nd.inputs[0] in args.inner_in_seqs or
nd.inputs[0] not in args.inner_inputs)):
valid_inputs = True
idx_matrix_input = 1
idx_vector_input = 0
if valid_inputs:
# The optimization can be applied on the current Dot
# Move out of scan the two inputs to the Dot
(outer_vars,
new_scan_node,
new_scan_args) = self.push_out_inner_vars(fgraph,
nd.inputs,
node, args)
outer_vector_input = outer_vars[idx_vector_input]
outer_matrix_input = outer_vars[idx_matrix_input]
# Perform the Dot outside of scan
if idx_matrix_input == 0:
outer_dot_inputs = [outer_vector_input,
outer_matrix_input.transpose()]
outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
else: # idx_matrix_input == 1
outer_dot_inputs = [outer_vector_input,
outer_matrix_input]
outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
# Modify the outer graph to add the outer Dot
fgraph.replace_all(
[(new_scan_args.outer_out_nit_sot[dot_out_nitsot_idx],
outer_dot_output)],
reason="scanOp_pushout_output")
break
elif (isinstance(nd.op, theano.tensor.elemwise.Elemwise) and
isinstance(nd.op.nfunc, numpy.ufunc) and
nd.op.nfunc.__name__ == 'add' and
nd.out in args.inner_out_sit_sot and
self.inner_sitsot_only_last_step_used(nd.out, args)):
# Ensure that one of the input to the add is the output of # Ensure that one of the input to the add is the output of
# the add from a previous iteration of the inner function # the add from a previous iteration of the inner function
...@@ -809,7 +727,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -809,7 +727,7 @@ class PushOutScanOutput(gof.Optimizer):
if (dot_input.owner is not None and if (dot_input.owner is not None and
isinstance(dot_input.owner.op, theano.tensor.Dot) and isinstance(dot_input.owner.op, theano.tensor.Dot) and
len(dot_input.clients) == 1 and len(clients[dot_input]) == 1 and
dot_input.owner.inputs[0].ndim == 2 and dot_input.owner.inputs[0].ndim == 2 and
dot_input.owner.inputs[1].ndim == 2 and dot_input.owner.inputs[1].ndim == 2 and
self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and
......
...@@ -3120,6 +3120,25 @@ class T_Scan(unittest.TestCase): ...@@ -3120,6 +3120,25 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(vnh0, tnh0, atol=1e-6) utt.assert_allclose(vnh0, tnh0, atol=1e-6)
utt.assert_allclose(vnW, tnW, atol=1e-6) utt.assert_allclose(vnW, tnW, atol=1e-6)
def test_pushout_dot(self):
    # Regression test for the scan "pushout" optimization: a Dot computed
    # inside the scan body and only consumed outside of it must be lifted
    # out of the scan, so the compiled inner graph contains no Dot apply.
    W = tensor.matrix('W')
    h = tensor.matrix('h')

    def step(hi, him1, W):
        # First output is recurrent (sit-sot); second is the Dot result.
        return hi, tensor.dot(hi + him1, W)

    outputs, _ = theano.scan(step,
                             outputs_info=[tensor.zeros([h.shape[1]]), None],
                             sequences=[h],
                             non_sequences=[W])
    fn = theano.function([W, h], outputs, mode=mode_with_opt)

    scan_applies = [node for node in fn.maker.fgraph.toposort()
                    if isinstance(node.op,
                                  theano.scan_module.scan_op.Scan)]
    assert len(scan_applies) == 1

    # After the optimization, no Dot node may remain in the inner function.
    inner_fgraph = scan_applies[0].op.fn.maker.fgraph
    assert not any(isinstance(inner_node.op, tensor.Dot)
                   for inner_node in inner_fgraph.apply_nodes)
def test_pushout_all(self): def test_pushout_all(self):
W1 = tensor.matrix('W1') W1 = tensor.matrix('W1')
W2 = tensor.matrix('W2') W2 = tensor.matrix('W2')
......
Markdown 格式
0%
您向此讨论添加了 0 人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论