提交 692d173b authored 作者: Frederic Bastien's avatar Frederic Bastien

pre merge constant.

上级 a614fa0c
......@@ -344,6 +344,37 @@ def MergeOptMerge(opt):
return SeqOptimizer([merger, opt, merger])
def pre_constant_merge(vars):
"""
Merge constants before the variables in the list `out`
Modify the nodes that are in the path to create out.
:note: This don't change node that are in an env.
This is used to pre-merge node generated in an optimization
that we don't want DebugMode to check as there is too many
"""
seen_var = set()
const_sig = {} # variable -> variable.signature() (for constants)
const_sig_inv = {} # signature -> variable (for constants)
def recursive_merge(var):
if var in seen_var:
return var
if var.owner and hasattr(var.owner, "env"):
return var
seen_var.add(var)
if isinstance(var, graph.Constant):
sig = var.signature()
if sig in const_sig_inv:
return const_sig_inv[sig]
const_sig_inv[sig] = var
return var
if var.owner:
for idx,inp in enumerate(var.owner.inputs):
var.owner.inputs[idx] = recursive_merge(inp)
return var
return map(recursive_merge, vars)
########################
### Local Optimizers ###
......
......@@ -25,7 +25,7 @@ import basic as T
from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer
from theano.gof.opt import Optimizer, pre_constant_merge
from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value
......@@ -1415,9 +1415,9 @@ def merge_two_slices(slice1, len1, slice2, len2):
stop = greedy_local_optimizer( list_opt, stop)
step = greedy_local_optimizer( list_opt, step)
#start = theano.printing.Print('start')(start)
#stop = theano.printing.Print('stop')(stop)
#step = theano.printing.Print('step')(step)
#Pre merge constant for the same reason.
start, stop, step = pre_constant_merge([start, stop, step])
return slice(start, stop, step)
@register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论