提交 b75cf2e1 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2823 from nouiz/faster_test

finish gh-2768
...@@ -717,7 +717,7 @@ def clone_get_equiv(inputs, outputs, ...@@ -717,7 +717,7 @@ def clone_get_equiv(inputs, outputs,
def general_toposort(r_out, deps, debug_print=False, def general_toposort(r_out, deps, debug_print=False,
_deps=None, deps_cache=None): compute_deps_cache=None, deps_cache=None):
"""WRITEME """WRITEME
:note: :note:
...@@ -729,15 +729,24 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -729,15 +729,24 @@ def general_toposort(r_out, deps, debug_print=False,
:note: :note:
The order of the return value list is determined by the order of nodes returned by the deps() function. The order of the return value list is determined by the order of nodes returned by the deps() function.
:param deps: a python function that take a node as input and
return its dependence.
:param compute_deps_cache: Optional,
if provided deps_cache should also be provided. This is a
function like deps, but that also cache its results in a dict
passed as deps_cache.
:param deps_cache: a dict. Must be used with compute_deps_cache.
:note: deps should be provided or can be None and the caller :note: deps should be provided or can be None and the caller
provide _deps and deps_cache. The second option remove a provide compute_deps_cache and deps_cache. The second option
Python function call, so is faster. remove a Python function call, and allow for more specialized
code, so it can be faster.
""" """
if _deps is None: if compute_deps_cache is None:
deps_cache = {} deps_cache = {}
def _deps(io): def compute_deps_cache(io):
if io not in deps_cache: if io not in deps_cache:
d = deps(io) d = deps(io)
if d: if d:
...@@ -751,10 +760,12 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -751,10 +760,12 @@ def general_toposort(r_out, deps, debug_print=False,
return d return d
else: else:
return deps_cache[io] return deps_cache[io]
assert deps_cache is not None
assert isinstance(r_out, (tuple, list, deque)) assert isinstance(r_out, (tuple, list, deque))
reachable, clients = stack_search(deque(r_out), _deps, 'dfs', True) reachable, clients = stack_search(deque(r_out), compute_deps_cache,
'dfs', True)
sources = deque([r for r in reachable if not deps_cache.get(r, None)]) sources = deque([r for r in reachable if not deps_cache.get(r, None)])
rset = set() rset = set()
...@@ -800,12 +811,12 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -800,12 +811,12 @@ def io_toposort(inputs, outputs, orderings=None):
# We build 2 functions as a speed up # We build 2 functions as a speed up
deps_cache = {} deps_cache = {}
deps = None compute_deps = None
_deps = None compute_deps_cache = None
if not orderings: # can be None or empty dict if not orderings: # can be None or empty dict
# Specialized function that is faster when no ordering. # Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up. # Also include the cache in the function itself for speed up.
def _deps(obj): def compute_deps_cache(obj):
if obj in deps_cache: if obj in deps_cache:
return deps_cache[io] return deps_cache[io]
rval = [] rval = []
...@@ -827,7 +838,7 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -827,7 +838,7 @@ def io_toposort(inputs, outputs, orderings=None):
deps_cache[obj] = rval deps_cache[obj] = rval
return rval return rval
else: else:
def deps(obj): def compute_deps(obj):
rval = [] rval = []
if obj not in iset: if obj not in iset:
if isinstance(obj, Variable): if isinstance(obj, Variable):
...@@ -840,7 +851,8 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -840,7 +851,8 @@ def io_toposort(inputs, outputs, orderings=None):
assert not orderings.get(obj, []) assert not orderings.get(obj, [])
return rval return rval
topo = general_toposort(outputs, deps=deps, _deps=_deps, topo = general_toposort(outputs, deps=compute_deps,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache) deps_cache=deps_cache)
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
......
...@@ -452,7 +452,7 @@ class test_canonize(unittest.TestCase): ...@@ -452,7 +452,7 @@ class test_canonize(unittest.TestCase):
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode = compile.mode.get_default_mode() mode = compile.mode.get_default_mode()
try:
opt = gof.Query(["canonicalize"]) opt = gof.Query(["canonicalize"])
opt = opt.including('ShapeOpt') opt = opt.including('ShapeOpt')
opt = opt.excluding( opt = opt.excluding(
...@@ -512,7 +512,7 @@ class test_canonize(unittest.TestCase): ...@@ -512,7 +512,7 @@ class test_canonize(unittest.TestCase):
# must broadcast as their is a dimshuffle in the computation # must broadcast as their is a dimshuffle in the computation
((dx/dv)/dx, [dx, dv], [dxv, dvv], 1, 'float64'), ((dx/dv)/dx, [dx, dv], [dxv, dvv], 1, 'float64'),
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc] # topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx/fv)/fx, [fx, fv], [fxv, fvv], 1, 'float32'), ((fx/fv)/fx, [fx, fv], [fxv, fvv], 1, 'float32'),
# topo:[Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc] # topo:[Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc]
]): ]):
...@@ -639,8 +639,6 @@ class test_canonize(unittest.TestCase): ...@@ -639,8 +639,6 @@ class test_canonize(unittest.TestCase):
assert numpy.all(numpy.isfinite(out)) assert numpy.all(numpy.isfinite(out))
assert numpy.allclose(out, numpy.sign(val_inputs[0]) * 2 / 3) assert numpy.allclose(out, numpy.sign(val_inputs[0]) * 2 / 3)
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
finally:
pass
def test_abs_mul_div(self): def test_abs_mul_div(self):
""" """
...@@ -701,12 +699,12 @@ class test_canonize(unittest.TestCase): ...@@ -701,12 +699,12 @@ class test_canonize(unittest.TestCase):
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode = compile.mode.get_default_mode() mode = compile.mode.get_default_mode()
try:
opt = gof.Query(["canonicalize"]) opt = gof.Query(["canonicalize"])
opt = opt.excluding( opt = opt.excluding(
'local_elemwise_fusion') 'local_elemwise_fusion')
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test fail! # test fail!
# test x / y / z -> x / (y * z) # test x / y / z -> x / (y * z)
for (g, sym_inputs, val_inputs, out_dtype) in [ for (g, sym_inputs, val_inputs, out_dtype) in [
((dx/dy)/dz, [dx, dy, dz], [dxv, dyv, dzv], 'float64'), ((dx/dy)/dz, [dx, dy, dz], [dxv, dyv, dzv], 'float64'),
...@@ -743,9 +741,6 @@ class test_canonize(unittest.TestCase): ...@@ -743,9 +741,6 @@ class test_canonize(unittest.TestCase):
assert len(topo[0].inputs) == 1 assert len(topo[0].inputs) == 1
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
finally:
pass
def test_dont_merge_if_multiple_client(self): def test_dont_merge_if_multiple_client(self):
""" test those case take from the comment in Canonizer """ test those case take from the comment in Canonizer
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论