提交 d89bd8ec authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5502 from ReyhaneAskari/io_toposort_5042

IO_toposort
...@@ -53,9 +53,9 @@ can lead to errors. Consider this example: ...@@ -53,9 +53,9 @@ can lead to errors. Consider this example:
>>> theano.printing.debugprint(f) # doctest: +NORMALIZE_WHITESPACE >>> theano.printing.debugprint(f) # doctest: +NORMALIZE_WHITESPACE
MakeVector{dtype='int64'} [id A] '' 4 MakeVector{dtype='int64'} [id A] '' 4
|Elemwise{Add}[(0, 0)] [id B] '' 3 |Elemwise{Add}[(0, 0)] [id B] '' 3
| |Shape_i{0} [id C] '' 1 | |Shape_i{0} [id C] '' 2
| | |x [id D] | | |x [id D]
| |Shape_i{0} [id E] '' 2 | |Shape_i{0} [id E] '' 1
| |y [id F] | |y [id F]
|Shape_i{1} [id G] '' 0 |Shape_i{1} [id G] '' 0
|x [id D] |x [id D]
......
...@@ -56,8 +56,8 @@ class Test_profiling(unittest.TestCase): ...@@ -56,8 +56,8 @@ class Test_profiling(unittest.TestCase):
lines1 = [l for l in the_string.split("\n") if "Max if linker" in l] lines1 = [l for l in the_string.split("\n") if "Max if linker" in l]
lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l] lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l]
if theano.config.device == 'cpu': if theano.config.device == 'cpu':
assert "CPU: 4112KB (8204KB)" in the_string, (lines1, lines2) assert "CPU: 4112KB (4104KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (12296KB)" in the_string, (lines1, lines2) assert "CPU: 8204KB (8196KB)" in the_string, (lines1, lines2)
assert "CPU: 8208KB" in the_string, (lines1, lines2) assert "CPU: 8208KB" in the_string, (lines1, lines2)
assert "Minimum peak from all valid apply node order is 4104KB" in the_string, ( assert "Minimum peak from all valid apply node order is 4104KB" in the_string, (
lines1, lines2) lines1, lines2)
......
...@@ -608,6 +608,8 @@ def stack_search(start, expand, mode='bfs', build_inv=False): ...@@ -608,6 +608,8 @@ def stack_search(start, expand, mode='bfs', build_inv=False):
expand : callable expand : callable
When we get to a node, add expand(node) to the list of nodes to visit. When we get to a node, add expand(node) to the list of nodes to visit.
This function should return a list, or None. This function should return a list, or None.
mode : string
'bfs' or 'dfs' for breath first search or depth first search.
Returns Returns
------- -------
...@@ -632,7 +634,7 @@ def stack_search(start, expand, mode='bfs', build_inv=False): ...@@ -632,7 +634,7 @@ def stack_search(start, expand, mode='bfs', build_inv=False):
start_pop = start.popleft start_pop = start.popleft
else: else:
start_pop = start.pop start_pop = start.pop
expand_inv = {} expand_inv = {} # var: clients
while start: while start:
l = start_pop() l = start_pop()
if id(l) not in rval_set: if id(l) not in rval_set:
...@@ -878,7 +880,7 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None): ...@@ -878,7 +880,7 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
return memo return memo
def general_toposort(r_out, deps, debug_print=False, def general_toposort(outputs, deps, debug_print=False,
compute_deps_cache=None, deps_cache=None, compute_deps_cache=None, deps_cache=None,
clients=None): clients=None):
""" """
...@@ -932,9 +934,9 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -932,9 +934,9 @@ def general_toposort(r_out, deps, debug_print=False,
return deps_cache[io] return deps_cache[io]
assert deps_cache is not None assert deps_cache is not None
assert isinstance(r_out, (tuple, list, deque)) assert isinstance(outputs, (tuple, list, deque))
reachable, _clients = stack_search(deque(r_out), compute_deps_cache, reachable, _clients = stack_search(deque(outputs), compute_deps_cache,
'dfs', True) 'dfs', True)
if clients is not None: if clients is not None:
clients.update(_clients) clients.update(_clients)
...@@ -948,9 +950,9 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -948,9 +950,9 @@ def general_toposort(r_out, deps, debug_print=False,
rlist.append(node) rlist.append(node)
rset.add(node) rset.add(node)
for client in _clients.get(node, []): for client in _clients.get(node, []):
deps_cache[client] = [a for a in deps_cache[client] d = [a for a in deps_cache[client] if a is not node]
if a is not node] deps_cache[client] = d
if not deps_cache[client]: if not d:
sources.append(client) sources.append(client)
if len(rlist) != len(reachable): if len(rlist) != len(reachable):
...@@ -980,17 +982,37 @@ def io_toposort(inputs, outputs, orderings=None, clients=None): ...@@ -980,17 +982,37 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
node->clients for each node in the subgraph that is sorted node->clients for each node in the subgraph that is sorted
""" """
# the inputs are used only here in the function that decides what 'predecessors' to explore if not orderings and clients is None: # ordering can be None or empty dict
iset = set(inputs) # Specialized function that is faster when more then ~10 nodes
# when no ordering.
# We build 2 functions as a speed up
deps_cache = {} # Do a new stack implementation with the vm algo.
# This will change the order returned.
computed = set(inputs)
todo = [o.owner for o in reversed(outputs) if o.owner]
order = []
while todo:
cur = todo.pop()
# We suppose that all outputs are always computed
if cur.outputs[0] in computed:
continue
if all([i in computed or i.owner is None for i in cur.inputs]):
computed.update(cur.outputs)
order.append(cur)
else:
todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner)
return order
compute_deps = None compute_deps = None
compute_deps_cache = None compute_deps_cache = None
if not orderings: # can be None or empty dict iset = set(inputs)
deps_cache = {}
if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering. # Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up. # Also include the cache in the function itself for speed up.
def compute_deps_cache(obj): def compute_deps_cache(obj):
if obj in deps_cache: if obj in deps_cache:
return deps_cache[obj] return deps_cache[obj]
...@@ -1013,6 +1035,9 @@ def io_toposort(inputs, outputs, orderings=None, clients=None): ...@@ -1013,6 +1035,9 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
deps_cache[obj] = rval deps_cache[obj] = rval
return rval return rval
else: else:
# the inputs are used only here in the function that decides what
# 'predecessors' to explore
def compute_deps(obj): def compute_deps(obj):
rval = [] rval = []
if obj not in iset: if obj not in iset:
...@@ -1023,7 +1048,7 @@ def io_toposort(inputs, outputs, orderings=None, clients=None): ...@@ -1023,7 +1048,7 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
rval = list(obj.inputs) rval = list(obj.inputs)
rval.extend(orderings.get(obj, [])) rval.extend(orderings.get(obj, []))
else: else:
assert not orderings.get(obj, []) assert not orderings.get(obj, None)
return rval return rval
topo = general_toposort(outputs, deps=compute_deps, topo = general_toposort(outputs, deps=compute_deps,
......
...@@ -212,7 +212,7 @@ class TestToposort: ...@@ -212,7 +212,7 @@ class TestToposort:
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(r3, r4) o1 = MyOp.make_node(r3, r4)
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
assert all == [o1, o0] assert all == [o1, o0] or all == [o0, o1]
def test_4(self): def test_4(self):
"""Test inputs and outputs mixed together in a chain graph""" """Test inputs and outputs mixed together in a chain graph"""
......
...@@ -153,7 +153,7 @@ class TestWrapLinker(unittest.TestCase): ...@@ -153,7 +153,7 @@ class TestWrapLinker(unittest.TestCase):
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
assert nodes == [div, add, mul] assert nodes == [div, add, mul] or nodes == [add, div, mul]
assert o[0].data is None assert o[0].data is None
def test_1(self): def test_1(self):
...@@ -171,7 +171,7 @@ class TestWrapLinker(unittest.TestCase): ...@@ -171,7 +171,7 @@ class TestWrapLinker(unittest.TestCase):
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
assert nodes == [div, add, mul] assert nodes == [div, add, mul] or nodes == [add, div, mul]
assert o[0].data == 1.5 assert o[0].data == 1.5
......
...@@ -1572,12 +1572,13 @@ class UsmmTests(unittest.TestCase): ...@@ -1572,12 +1572,13 @@ class UsmmTests(unittest.TestCase):
# Usmm is tested at the same time in debugmode # Usmm is tested at the same time in debugmode
# Check if the optimization local_usmm and local_usmm_csx is # Check if the optimization local_usmm and local_usmm_csx is
# applied # applied
assert isinstance(topo[0].op, def check_once(x):
theano.sparse.basic.CSMProperties) assert sum([isinstance(n.op, x) for n in topo]) == 1
assert isinstance(topo[1].op, theano.tensor.DimShuffle) check_once(theano.sparse.basic.CSMProperties)
assert isinstance(topo[2].op, theano.tensor.Subtensor) check_once(theano.tensor.DimShuffle)
assert topo[3].op == theano.tensor.neg check_once(theano.tensor.Subtensor)
assert isinstance(topo[4].op, UsmmCscDense) check_once(UsmmCscDense)
check_once(theano.tensor.Elemwise)
if inplace: if inplace:
assert topo[4].op.inplace assert topo[4].op.inplace
elif not fast_compile: elif not fast_compile:
......
...@@ -1629,7 +1629,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1629,7 +1629,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
def c_code_cache_version_apply(self, node): def c_code_cache_version_apply(self, node):
# the version corresponding to the c code in this Op # the version corresponding to the c code in this Op
version = [6] version = [7]
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply( scalar_node = Apply(
......
...@@ -100,13 +100,13 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -100,13 +100,13 @@ def make_checks(loop_orders, dtypes, sub):
check += """ check += """
if (%%(lv%(j0)s)s_n%(x0)s != %%(lv%(j)s)s_n%(x)s) if (%%(lv%(j0)s)s_n%(x0)s != %%(lv%(j)s)s_n%(x)s)
{ {
PyErr_Format(PyExc_ValueError, "Input dimension mis-match. (input[%%%%i].shape[%%%%i] = %%%%i, input[%%%%i].shape[%%%%i] = %%%%i)", PyErr_Format(PyExc_ValueError, "Input dimension mis-match. (input[%%%%i].shape[%%%%i] = %%%%lli, input[%%%%i].shape[%%%%i] = %%%%lli)",
%(j0)s, %(j0)s,
%(x0)s, %(x0)s,
%%(lv%(j0)s)s_n%(x0)s, (long long int) %%(lv%(j0)s)s_n%(x0)s,
%(j)s, %(j)s,
%(x)s, %(x)s,
%%(lv%(j)s)s_n%(x)s (long long int) %%(lv%(j)s)s_n%(x)s
); );
%%(fail)s %%(fail)s
} }
......
...@@ -256,8 +256,10 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -256,8 +256,10 @@ class T_sigmoid_opts(unittest.TestCase):
[x, y], [x, y],
(sigmoid(x) * sigmoid(-y) * -tensor.exp(-x) * (sigmoid(x) * sigmoid(-y) * -tensor.exp(-x) *
tensor.exp(x * y) * tensor.exp(y)), mode=m) tensor.exp(x * y) * tensor.exp(y)), mode=m)
match(f, [sigmoid, tensor.mul, tensor.neg, tensor.exp, sigmoid, topo = f.maker.fgraph.toposort()
tensor.mul]) for op, nb in [(sigmoid, 2), (tensor.mul, 2),
(tensor.neg, 1), (tensor.exp, 1)]:
assert sum([n.op == op for n in topo]) == nb
# assert check_stack_trace(f, ops_to_check=[sigmoid, tensor.mul, # assert check_stack_trace(f, ops_to_check=[sigmoid, tensor.mul,
# tensor.exp]) # tensor.exp])
......
...@@ -1568,14 +1568,19 @@ def test_log1p(): ...@@ -1568,14 +1568,19 @@ def test_log1p():
y = fmatrix() y = fmatrix()
f = function([x, y], T.log(tensor.fill(y, 1) + (x)), mode=m) f = function([x, y], T.log(tensor.fill(y, 1) + (x)), mode=m)
# the first three ops are Shape_i, Shape_i, and Dimshuffle # the first three ops are Shape_i, Shape_i, and Dimshuffle
assert [node.op for node in f.maker.fgraph.toposort()][3:] == [ topo = f.maker.fgraph.toposort()
T.log1p, tensor.alloc] assert topo[-1].op == tensor.alloc
assert T.log1p in [node.op for node in topo]
f = function([x, y], T.log(0 + (x) + tensor.fill(y, 1.0)), mode=m) f = function([x, y], T.log(0 + (x) + tensor.fill(y, 1.0)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()][3:] == [ topo = f.maker.fgraph.toposort()
T.log1p, tensor.alloc] assert topo[-1].op == tensor.alloc
assert T.log1p in [node.op for node in topo]
f = function([x, y], T.log(2 + (x) - tensor.fill(y, 1.0)), mode=m) f = function([x, y], T.log(2 + (x) - tensor.fill(y, 1.0)), mode=m)
assert ([node.op for node in f.maker.fgraph.toposort()][3:] == topo = f.maker.fgraph.toposort()
[T.log1p, tensor.alloc]) assert topo[-1].op == tensor.alloc
assert T.log1p in [node.op for node in topo]
f([1e-7, 10], [[0, 0], [0, 0]]) # debugmode will verify values f([1e-7, 10], [[0, 0], [0, 0]]) # debugmode will verify values
...@@ -2207,8 +2212,9 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -2207,8 +2212,9 @@ class test_local_subtensor_lift(unittest.TestCase):
assert isinstance(prog[0].op, tensor.DimShuffle) assert isinstance(prog[0].op, tensor.DimShuffle)
assert isinstance(prog[1].op.scalar_op, theano.scalar. assert isinstance(prog[1].op.scalar_op, theano.scalar.
Composite) # Composite{add,exp} Composite) # Composite{add,exp}
assert prog[2].op == tensor.add assert prog[2].op == tensor.add or prog[3].op == tensor.add
assert isinstance(prog[3].op, tensor.Subtensor) # first subtensor # first subtensor
assert isinstance(prog[2].op, tensor.Subtensor) or isinstance(prog[3].op, tensor.Subtensor)
assert len(prog) == 4 assert len(prog) == 4
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
......
...@@ -252,11 +252,11 @@ def test_debugprint(): ...@@ -252,11 +252,11 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = '\n'.join([ reference = '\n'.join([
"Elemwise{add,no_inplace} [id A] '' 0 clients:[('[id B]', 1), ('output', '')]", "Elemwise{add,no_inplace} [id A] '' 0 clients:[('output', ''), ('[id C]', 1)]",
" |A [id D]", " |A [id D]",
" |B [id E]", " |B [id E]",
"Elemwise{sub,no_inplace} [id B] '' 1", "Elemwise{sub,no_inplace} [id C] '' 1",
" |Elemwise{add,no_inplace} [id A] '' 0 clients:[('[id B]', 1), ('output', '')]", " |Elemwise{add,no_inplace} [id A] '' 0 clients:[('output', ''), ('[id C]', 1)]",
" |D [id F]", " |D [id F]",
]) + '\n' ]) + '\n'
if s != reference: if s != reference:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论