提交 90040b7d 作者: Razvan Pascanu

new optimization

上级 a2ea4492
......@@ -1456,3 +1456,218 @@ scan_seqopt.register('scanOp_merge_inouts',
3,
'fast_run',
'scan')
@gof.local_optimizer([None])
def scan_pushout_dot1(node):
    """Pull an accumulated ``dot`` out of a ``Scan`` node.

    Looks for a sit_sot recurrence of the form::

        x[t] = x[t-1] + dot(a[t], b[t])

    where both ``dot`` operands are 2-D inside the scan, at least one of
    them is an inner sequence, and the only client of the outer sit_sot
    output is a ``Subtensor`` with ``idx_list == (-1,)`` (only the last
    value is used).  In that case the sit_sot is removed from the scan,
    the non-sequence operand is exported as an extra nit_sot output, and
    the sum over time is rebuilt outside the scan as one ``dot`` between
    reshaped versions of the outer sequence and the collected nit_sot.

    :param node: the Apply node being visited by the optimizer.
    :return: ``False`` when the pattern does not match; otherwise the list
        of variables replacing ``node.outputs`` (same order and length).
    """
    if not isinstance(node.op, scan_op.Scan):
        return False
    op = node.op
    sitsot_ins = op.inner_sitsot(op.inputs)
    sitsot_outs = op.inner_sitsot_outs(op.outputs)
    outer_sitsot_results = op.outer_sitsot_outs(node)
    seqs = op.inner_seqs(op.inputs)
    for inp, out, outer_out in zip(sitsot_ins, sitsot_outs,
                                   outer_sitsot_results):
        # The inner output must be ``state + something`` and the outer
        # output must be consumed only through ``[-1]``.
        if not (out.owner and
                isinstance(out.owner.op, theano.tensor.Elemwise) and
                isinstance(out.owner.op.scalar_op, theano.scalar.Add) and
                inp in out.owner.inputs and
                len(outer_out.clients) == 1 and
                not isinstance(outer_out.clients[0][0], str) and
                isinstance(outer_out.clients[0][0].op,
                           theano.tensor.Subtensor) and
                outer_out.clients[0][0].op.idx_list == (-1,)):
            continue

        # Pick the addend that is not the recurrent state itself.
        x = out.owner.inputs[0]
        if x == inp:
            x = out.owner.inputs[1]
        # We need to check if x is the result of a matrix-matrix dot
        if not (x.owner and
                isinstance(x.owner.op, theano.tensor.Dot) and
                x.owner.inputs[0].ndim == 2 and
                x.owner.inputs[1].ndim == 2):
            continue

        # We need to check if any of the dot inputs is a sequence
        inp1 = x.owner.inputs[0]
        inp2 = x.owner.inputs[1]
        if not (inp1 in seqs or inp2 in seqs):
            continue

        # The operand rebuilt from the outer sequence is the left one
        # whenever it is a sequence; the *other* operand is exported from
        # the scan.  Deciding both from the same flag keeps the two
        # choices consistent even when both operands are sequences.
        seq_is_left = inp1 in seqs
        new_scan_out = inp2 if seq_is_left else inp1
        idx = sitsot_outs.index(out)

        # We've found our pattern and need to construct a new scan node
        # to replace this one. For this we need to replace the sit_sot
        # output with a nit_sot output.
        # First let us split all arguments according to their
        # corresponding categories
        inner_seqs = op.inner_seqs(op.inputs)
        outer_seqs = op.outer_seqs(node)
        inner_mitmot = op.inner_mitmot(op.inputs)
        outer_mitmot = op.outer_mitmot(node)
        inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
        inner_mitsot = op.inner_mitsot(op.inputs)
        outer_mitsot = op.outer_mitsot(node)
        inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
        inner_sitsot = op.inner_sitsot(op.inputs)
        outer_sitsot = op.outer_sitsot(node)
        inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
        outer_nitsot = op.outer_nitsot(node)
        inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
        inner_shared = op.inner_shared(op.inputs)
        outer_shared = op.outer_shared(node)
        inner_shared_outs = op.inner_shared_outs(op.outputs)
        inner_non_seqs = op.inner_non_seqs(op.inputs)
        outer_non_seqs = op.outer_non_seqs(node)

        # Drop the matched sit_sot (taps, inner input, outer input and
        # inner output) and account for one more nit_sot.
        new_info = op.info.copy()
        st = len(op.mitmot_taps()) + len(op.mitsot_taps())
        new_info['tap_array'] = (new_info['tap_array'][:st + idx] +
                                 new_info['tap_array'][st + idx + 1:])
        new_info['n_sit_sot'] -= 1
        new_info['n_nit_sot'] += 1
        inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1:]
        outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1:]
        inner_sitsot_outs = (inner_sitsot_outs[:idx] +
                             inner_sitsot_outs[idx + 1:])
        inner_nitsot_outs.append(new_scan_out)

        # Infer the per-step shape of the new nit_sot output from the
        # shapes of the outer inputs (leading time axis stripped for
        # sequences and recurrent states).
        shape_of = node.fgraph.shape_feature.shape_of
        input_shapes = [shape_of[outer_in] for outer_in in node.inputs]
        seqs_shape = [shp[1:] for shp in input_shapes[1:1 + op.n_seqs]]
        # mit_mot, mit_sot, sit_sot: one inner input per tap
        n_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot
        outs_shape = []
        for _idx in xrange(n_outs):
            for _tap in op.tap_array[_idx]:
                outs_shape.append(input_shapes[_idx + op.n_seqs + 1][1:])
        # shared_outs
        offset = 1 + op.n_seqs + n_outs
        for _idx in xrange(op.n_shared_outs):
            outs_shape.append(input_shapes[_idx + offset])
        # non_sequences
        offset += op.n_nit_sot + op.n_shared_outs
        inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
        assert len(inner_ins_shapes) == len(op.inputs)
        grab_shape = scan_utils.infer_shape(
            [new_scan_out], op.inputs, inner_ins_shapes)[0]

        # Keep only shape expressions that can be expressed with outer
        # variables; inner non-sequences map onto their outer twins.
        out_equivalent = {}
        for in_ns, out_ns in zip(op.inner_non_seqs(op.inputs),
                                 op.outer_non_seqs(node.inputs)):
            out_equivalent[in_ns] = out_ns
        validator = scan_utils.Validator(
            valid=input_shapes,
            invalid=op.inputs,
            valid_equivalent=out_equivalent)
        grab_shape = [validator.check(dim) for dim in grab_shape]

        # The outer input paired with the new nit_sot is either a
        # preallocated buffer (when every dimension is known and the op
        # asks for buffers) or simply the number of steps.
        if (all(dim is not None for dim in grab_shape) and
                op.nit_sot_buffers):
            new_info['nit_sot_buffers'] = True
            new_nitsot_arg = tensor.zeros(
                [node.inputs[0]] + [dim[0] for dim in grab_shape])
        else:
            new_info['nit_sot_buffers'] = False
            new_nitsot_arg = node.inputs[0]

        _new_inner_inps = (inner_seqs +
                           inner_mitmot +
                           inner_mitsot +
                           inner_sitsot +
                           inner_shared +
                           inner_non_seqs)
        _new_inner_outs = (inner_mitmot_outs +
                           inner_mitsot_outs +
                           inner_sitsot_outs +
                           inner_nitsot_outs +
                           inner_shared_outs)
        new_inner_inps, new_inner_outs = scan_utils.reconstruct_graph(
            _new_inner_inps, _new_inner_outs)
        new_op = scan_op.Scan(new_inner_inps, new_inner_outs, new_info)
        _scan_inputs = ([node.inputs[0]] +
                        outer_seqs +
                        outer_mitmot +
                        outer_mitsot +
                        outer_sitsot +
                        outer_shared +
                        outer_nitsot +
                        [new_nitsot_arg] +
                        outer_non_seqs)
        new_outs = new_op(*_scan_inputs)
        # Calling an op with a single output yields the bare variable
        # rather than a list; normalise so the slicing helpers below
        # always operate on a list of outputs.
        if not isinstance(new_outs, (list, tuple)):
            new_outs = [new_outs]

        # We need now to pair correctly the new outputs with the old ones
        outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
        outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
        outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
        outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
        outer_shared_outs = new_op.outer_shared_outs(new_outs)
        # The nit_sot we appended is the last one; it is consumed here
        # and must not show up among the replacement outputs.
        _val = outer_nitsot_outs[-1]
        outer_nitsot_outs = outer_nitsot_outs[:-1]

        if seq_is_left:
            _out_seq = op.outer_seqs(node)[seqs.index(inp1)]
            # We need to clip the seq to the number of steps
            _out_seq = _out_seq[:node.inputs[0]]
            # (T, a, b) -> (a, T * b)
            out_seq = _out_seq.dimshuffle(1, 0, 2).reshape(
                (_out_seq.shape[1],
                 _out_seq.shape[0] * _out_seq.shape[2]))
            # (T, b, c) -> (T * b, c)
            val = _val.reshape(
                (_val.shape[0] * _val.shape[1], _val.shape[2]))
            new_out = tensor.dot(out_seq, val)
        else:
            _out_seq = op.outer_seqs(node)[seqs.index(inp2)]
            # Clip here as well so the flattened time axis matches the
            # n_steps rows collected in ``_val``.
            _out_seq = _out_seq[:node.inputs[0]]
            # (T, b, c) -> (T * b, c)
            out_seq = _out_seq.reshape(
                (_out_seq.shape[0] * _out_seq.shape[1],
                 _out_seq.shape[2]))
            # (T, a, b) -> (a, T * b)
            val = _val.dimshuffle(1, 0, 2).reshape(
                (_val.shape[1],
                 _val.shape[0] * _val.shape[2]))
            new_out = tensor.dot(val, out_seq)
        # Add a leading, non-broadcastable axis so that the existing
        # ``[-1]`` Subtensor client selects the computed matrix.
        # NOTE(review): the initial value of the removed sit_sot state is
        # not added to ``new_out`` -- confirm this is only meant to fire
        # for zero-initialised accumulators.
        new_out = tensor.unbroadcast(
            new_out.dimshuffle('x', 0, 1), 0)

        outer_sitsot_outs = (outer_sitsot_outs[:idx] +
                             [new_out] +
                             outer_sitsot_outs[idx:])
        final_outs = (outer_mitmot_outs +
                      outer_mitsot_outs +
                      outer_sitsot_outs +
                      outer_nitsot_outs +
                      outer_shared_outs)
        return final_outs
    return False
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论