提交 95dd414b authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make PatternSub register the proper trackings and fix a bunch of call sites of…

Make PatternSub register the proper trackings and fix a bunch of call sites of local_optimizer to use the proper calling form. Most of the tests pass now, except a bunch in the Cuda backend.
上级 1c974f4f
...@@ -1022,17 +1022,7 @@ class PatternSub(LocalOptimizer): ...@@ -1022,17 +1022,7 @@ class PatternSub(LocalOptimizer):
return self.op return self.op
def tracks(self): def tracks(self):
def helper(pattern, sofar): return [self.op]
if isinstance(pattern, (list, tuple)):
sofar = sofar + (pattern[0],)
return reduce(tuple.__add__,
tuple(helper(p, sofar) for p in pattern[1:]),
())
elif isinstance(pattern, dict):
return helper(pattern['pattern'], sofar)
else:
return (sofar,)
return set(helper(self.in_pattern, ()))
def transform(self, node): def transform(self, node):
""" """
......
...@@ -384,7 +384,7 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -384,7 +384,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
return tuple(rval) return tuple(rval)
@gof.local_optimizer([None]) @gof.local_optimizer([IfElse])
def cond_make_inplace(node): def cond_make_inplace(node):
op = node.op op = node.op
if isinstance(op, IfElse) and not op.as_view: if isinstance(op, IfElse) and not op.as_view:
...@@ -445,7 +445,7 @@ acceptable_ops = (theano.tensor.basic.Dot, ...@@ -445,7 +445,7 @@ acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.elemwise.DimShuffle) theano.tensor.elemwise.DimShuffle)
@gof.local_optimizer([None]) @gof.local_optimizer(acceptable_ops)
def ifelse_lift_single_if_through_acceptable_ops(main_node): def ifelse_lift_single_if_through_acceptable_ops(main_node):
"""This optimization lifts up certain ifelse instances. """This optimization lifts up certain ifelse instances.
...@@ -493,7 +493,7 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node): ...@@ -493,7 +493,7 @@ def ifelse_lift_single_if_through_acceptable_ops(main_node):
return nw_outs return nw_outs
@gof.local_optimizer([None]) @gof.local_optimizer([IfElse])
def cond_merge_ifs_true(node): def cond_merge_ifs_true(node):
op = node.op op = node.op
if not isinstance(op, IfElse): if not isinstance(op, IfElse):
...@@ -517,7 +517,7 @@ def cond_merge_ifs_true(node): ...@@ -517,7 +517,7 @@ def cond_merge_ifs_true(node):
return op(*old_ins, **dict(return_list=True)) return op(*old_ins, **dict(return_list=True))
@gof.local_optimizer([None]) @gof.local_optimizer([IfElse])
def cond_merge_ifs_false(node): def cond_merge_ifs_false(node):
op = node.op op = node.op
if not isinstance(op, IfElse): if not isinstance(op, IfElse):
...@@ -592,7 +592,7 @@ class CondMerge(gof.Optimizer): ...@@ -592,7 +592,7 @@ class CondMerge(gof.Optimizer):
fgraph.replace_all_validate(pairs, reason='cond_merge') fgraph.replace_all_validate(pairs, reason='cond_merge')
@gof.local_optimizer([None]) @gof.local_optimizer([IfElse])
def cond_remove_identical(node): def cond_remove_identical(node):
op = node.op op = node.op
...@@ -643,7 +643,7 @@ def cond_remove_identical(node): ...@@ -643,7 +643,7 @@ def cond_remove_identical(node):
return rval return rval
@gof.local_optimizer([None]) @gof.local_optimizer([IfElse])
def cond_merge_random_op(main_node): def cond_merge_random_op(main_node):
if isinstance(main_node.op, IfElse): if isinstance(main_node.op, IfElse):
return False return False
......
...@@ -72,7 +72,7 @@ def hints(variable): ...@@ -72,7 +72,7 @@ def hints(variable):
@register_canonicalize @register_canonicalize
@local_optimizer([]) @local_optimizer([Hint])
def remove_hint_nodes(node): def remove_hint_nodes(node):
if is_hint_node(node): if is_hint_node(node):
# transfer hints from graph to Feature # transfer hints from graph to Feature
...@@ -224,7 +224,7 @@ def is_positive(v): ...@@ -224,7 +224,7 @@ def is_positive(v):
@register_stabilize @register_stabilize
@local_optimizer([]) @local_optimizer([Dot, Dot22])
def inv_as_solve(node): def inv_as_solve(node):
if not imported_scipy: if not imported_scipy:
return False return False
...@@ -242,7 +242,7 @@ def inv_as_solve(node): ...@@ -242,7 +242,7 @@ def inv_as_solve(node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([]) @local_optimizer([DimShuffle])
def no_transpose_symmetric(node): def no_transpose_symmetric(node):
if isinstance(node.op, DimShuffle): if isinstance(node.op, DimShuffle):
x = node.inputs[0] x = node.inputs[0]
...@@ -253,7 +253,7 @@ def no_transpose_symmetric(node): ...@@ -253,7 +253,7 @@ def no_transpose_symmetric(node):
@register_stabilize @register_stabilize
@local_optimizer([]) @local_optimizer(None) # XXX: solve is defined later and can't be used here
def psd_solve_with_chol(node): def psd_solve_with_chol(node):
if node.op == solve: if node.op == solve:
A, b = node.inputs # result is solution Ax=b A, b = node.inputs # result is solution Ax=b
...@@ -269,7 +269,7 @@ def psd_solve_with_chol(node): ...@@ -269,7 +269,7 @@ def psd_solve_with_chol(node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([]) @local_optimizer(None) # XXX: det is defined later and can't be used here
def local_det_chol(node): def local_det_chol(node):
""" """
If we have det(X) and there is already an L=cholesky(X) If we have det(X) and there is already an L=cholesky(X)
...@@ -287,7 +287,7 @@ def local_det_chol(node): ...@@ -287,7 +287,7 @@ def local_det_chol(node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([]) @local_optimizer([tensor.log])
def local_log_prod_sqr(node): def local_log_prod_sqr(node):
if node.op == tensor.log: if node.op == tensor.log:
x, = node.inputs x, = node.inputs
...@@ -307,7 +307,7 @@ def local_log_prod_sqr(node): ...@@ -307,7 +307,7 @@ def local_log_prod_sqr(node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([]) @local_optimizer([tensor.log])
def local_log_pow(node): def local_log_pow(node):
if node.op == tensor.log: if node.op == tensor.log:
x, = node.inputs x, = node.inputs
......
...@@ -337,7 +337,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp): ...@@ -337,7 +337,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
""" % locals() """ % locals()
@local_optimizer() @local_optimizer([MultinomialFromUniform])
def local_gpu_multinomial(node): def local_gpu_multinomial(node):
if type(node.op) is MultinomialFromUniform: if type(node.op) is MultinomialFromUniform:
p, u = node.inputs p, u = node.inputs
......
...@@ -941,7 +941,7 @@ class MRG_RandomStreams(object): ...@@ -941,7 +941,7 @@ class MRG_RandomStreams(object):
return final_samples return final_samples
@local_optimizer([None]) @local_optimizer([mrg_uniform])
def mrg_random_make_inplace(node): def mrg_random_make_inplace(node):
op = node.op op = node.op
if isinstance(op, mrg_uniform) and not op.inplace: if isinstance(op, mrg_uniform) and not op.inplace:
......
...@@ -32,7 +32,7 @@ sparse.register_specialize(local_csm_properties_csm) ...@@ -32,7 +32,7 @@ sparse.register_specialize(local_csm_properties_csm)
# This is tested in tests/test_basic.py:test_remove0 # This is tested in tests/test_basic.py:test_remove0
@gof.local_optimizer([None]) @gof.local_optimizer([sparse.Remove0])
def local_inplace_remove0(node): def local_inplace_remove0(node):
""" """
Optimization to insert inplace versions of Remove0. Optimization to insert inplace versions of Remove0.
...@@ -49,7 +49,7 @@ theano.compile.optdb.register('local_inplace_remove0', ...@@ -49,7 +49,7 @@ theano.compile.optdb.register('local_inplace_remove0',
gof.TopoOptimizer(local_inplace_remove0, gof.TopoOptimizer(local_inplace_remove0,
failure_callback=gof.TopoOptimizer.warn_inplace), failure_callback=gof.TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace') 60, 'fast_run', 'inplace')
@gof.local_optimizer([None]) @gof.local_optimizer([sparse.AddSD])
def local_inplace_addsd(node): def local_inplace_addsd(node):
""" """
Optimization to insert inplace versions of AddSD. Optimization to insert inplace versions of AddSD.
......
...@@ -266,7 +266,7 @@ def make_gpu_optimizer(op, to_gpu): ...@@ -266,7 +266,7 @@ def make_gpu_optimizer(op, to_gpu):
:param to_gpu: a list of op inputs that are moved to the GPU. :param to_gpu: a list of op inputs that are moved to the GPU.
""" """
@theano.gof.local_optimizer([]) @theano.gof.local_optimizer([op, cuda.gpu_from_host])
def local_to_gpu(node): def local_to_gpu(node):
""" """
op(host_from_gpu()) -> host_from_gpu(op) op(host_from_gpu()) -> host_from_gpu(op)
...@@ -302,7 +302,7 @@ if cuda.cuda_available: ...@@ -302,7 +302,7 @@ if cuda.cuda_available:
make_gpu_optimizer(IncDiagonalSubtensor, [0, 3]) make_gpu_optimizer(IncDiagonalSubtensor, [0, 3])
@theano.gof.local_optimizer([None]) @theano.gof.local_optimizer([DiagonalSubtensor, IncDiagonalSubtensor])
def local_inplace_DiagonalSubtensor(node): def local_inplace_DiagonalSubtensor(node):
""" also work for IncDiagonalSubtensor """ """ also work for IncDiagonalSubtensor """
if (isinstance(node.op, (DiagonalSubtensor, IncDiagonalSubtensor)) and if (isinstance(node.op, (DiagonalSubtensor, IncDiagonalSubtensor)) and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论