提交 6626f485 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new optimization for scan that removes constants and useless inputs

The scope is to make the computational graph more readable, and to move constants in the inner graph, allowing constant folding in the inner graph to do its job (performance).
上级 27d60c4f
......@@ -43,6 +43,75 @@ def warning(*msg):
def info(*msg):
_logger.info('INFO theano.scan: '+' '.join(msg))
@gof.local_optimizer([None])
def remove_constants_and_unused_inputs_scan(node):
if not isinstance(node.op, scan_op.Scan):
return False
op = node.op
# We only need to take care of sequences and other arguments
st = op.n_seqs
st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ]))
st += op.n_sit_sot
st += op.n_shared_outs
op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs,
'')
out_stuff_inner = op_ins[op.n_seqs:st]
non_seqs = op_ins[st:]
st = ( op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs +1 )
outer_non_seqs = node.inputs[st:]
out_stuff_outer = node.inputs[1+op.n_seqs:st]
givens = {}
nw_inner = []
nw_outer = [node.inputs[0]]
all_ins = gof.graph.inputs(op_outs)
for idx in xrange(op.n_seqs):
if (isinstance(node.inputs[idx+1], tensor.TensorConstant) and
node.inputs[idx+1].tag.unique_value is not None):
try:
val = tensor.get_constant_value(node.inputs[idx+1],
return_ndarray = True)
givens[op_ins[idx]] = tensor.constant(val[0])
except TypeError:
pass
elif op_ins[idx] in all_ins:
nw_inner += [op_ins[idx]]
nw_outer += [node.inputs[idx+1]]
nw_n_seqs = len(nw_inner)
# Add outputs stuff
nw_inner += out_stuff_inner
nw_outer += out_stuff_outer
# Look through non sequences
for nw_in, nw_out in zip(non_seqs, outer_non_seqs):
if isinstance(nw_out, tensor.Constant):
givens[nw_in] = nw_out.clone()
elif nw_in in all_ins:
nw_inner += [nw_in]
nw_outer += [nw_out]
if len(nw_inner) != len(op_ins):
op_outs = scan_utils.clone(op_outs, replace = givens)
nw_info = op.info.copy()
nw_info['n_seqs'] = nw_n_seqs
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan.make_node(*nw_outer).outputs
return nw_outs
else:
return False
optdb.register( 'scanOp_remove_constants_and_unused_inputs'
, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees = True)
, 1.995
, 'fast_run'
, 'scan')
@gof.local_optimizer([None])
def scan_make_inplace(node):
op = node.op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论