Commit 2a03db63 authored by Razvan Pascanu

Transform the optimization into a global one to deal with some issues

Parent 63353634
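
For context, the optimizer touched by this commit rewrites scan loops whose recurrent state does nothing but accumulate a dot product and whose last value is the only one consumed afterwards. Below is a minimal, illustrative sketch of that user-level pattern (the names seq, value, x0 and the shapes are made up for the example and are not part of the commit):

    import theano
    import theano.tensor as tensor

    seq = tensor.tensor3('seq')      # (n_steps, d1, d2)
    value = tensor.matrix('value')   # (d2, d3)
    x0 = tensor.matrix('x0')         # (d1, d3), initial state

    # x[t] = x[t-1] + dot(seq[t], value); only x[-1] is used afterwards,
    # which is exactly the pattern matched by the code in this diff.
    states, _ = theano.scan(
        lambda s_t, x_tm1: x_tm1 + tensor.dot(s_t, value),
        sequences=seq,
        outputs_info=x0)
    last_state = states[-1]
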
@@ -1410,170 +1410,177 @@ def scan_merge_inouts(node):
return na.outer_outputs
@gof.local_optimizer([None])
def scan_pushout_dot1(node):
if not isinstance(node.op, scan_op.Scan):
return False
# Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value
# When seq[t] is a vector/matrix and `value` is a matrix
# Note that this works only when you need X[-1] at the end
# and assumes dimshuffles are applied to vectors before calling dot
op = node.op
sitsot_ins = op.inner_sitsot(op.inputs)
sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_sitsot = op.outer_sitsot_outs(node)
seqs = op.inner_seqs(op.inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
if (out.owner and
isinstance(out.owner.op, theano.tensor.Elemwise) and
isinstance(out.owner.op.scalar_op, theano.scalar.Add) and
inp in out.owner.inputs and
len(outer_out.clients) == 1 and
not isinstance(outer_out.clients[0][0], str) and
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor)
and outer_out.clients[0][0].op.idx_list == (-1,)):
x = out.owner.inputs[0]
if x == inp:
x = out.owner.inputs[1]
# We need to check whether x is the result of a dot product of two matrices
if (x.owner and
isinstance(x.owner.op, theano.tensor.Dot) and
x.owner.inputs[0].ndim == 2 and
x.owner.inputs[1].ndim == 2):
# We need to check if any of the inputs are a sequence
inp1 = x.owner.inputs[0]
inp2 = x.owner.inputs[1]
if inp1 in seqs or inp2 in seqs:
new_scan_out = inp2
if inp2 in seqs:
new_scan_out = inp1
idx = sitsot_outs.index(out)
# We've found our pattern and need to construct a new
# scan node to replace this one. For this we need to
# replace the sit_sot output with a nit_sot output
# First let us split all arguments according to their
# corresponding categories
inner_seqs = op.inner_seqs(op.inputs)
outer_seqs = op.outer_seqs(node)
inner_mitmot = op.inner_mitmot(op.inputs)
outer_mitmot = op.outer_mitmot(node)
inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
inner_mitsot = op.inner_mitsot(op.inputs)
outer_mitsot = op.outer_mitsot(node)
inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
inner_sitsot = op.inner_sitsot(op.inputs)
outer_sitsot = op.outer_sitsot(node)
inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_nitsot = op.outer_nitsot(node)
inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
inner_shared = op.inner_shared(op.inputs)
outer_shared = op.outer_shared(node)
inner_shared_outs = op.inner_shared_outs(op.outputs)
inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node)
new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info['tap_array'] = (new_info['tap_array'][:st + idx] +
new_info['tap_array'][st + idx + 1:])
new_info['n_sit_sot'] -= 1
new_info['n_nit_sot'] += 1
inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1:]
outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1:]
inner_sitsot_outs = inner_sitsot_outs[:idx] +\
inner_sitsot_outs[idx + 1:]
# add n_steps as the length
inner_nitsot_outs.append(new_scan_out)
_new_inner_inps = (inner_seqs +
inner_mitmot +
inner_mitsot +
inner_sitsot +
inner_shared +
inner_non_seqs)
_new_inner_outs = (inner_mitmot_outs +
inner_mitsot_outs +
inner_sitsot_outs +
inner_nitsot_outs +
inner_shared_outs)
new_inner_inps, new_inner_outs =\
scan_utils.reconstruct_graph(
_new_inner_inps, _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info)
_scan_inputs = ([node.inputs[0]] +
outer_seqs +
outer_mitmot +
outer_mitsot +
outer_sitsot +
outer_shared +
outer_nitsot +
[node.inputs[0]] +
outer_non_seqs)
new_outs = new_op(*_scan_inputs)
# We now need to pair the new outputs correctly with
# the old ones
outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
outer_shared_outs = new_op.outer_shared_outs(new_outs)
_val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1]
if inp1 in seqs:
_out_seq = op.outer_seqs(node)[seqs.index(inp1)]
# We need to clip the seq to the number of steps
_out_seq = _out_seq[:node.inputs[0]]
sh0 = _out_seq.shape[0]
sh1 = _out_seq.shape[1]
sh2 = _out_seq.shape[2]
out_seq = _out_seq.dimshuffle(1, 0, 2)
out_seq = out_seq.reshape((sh1, sh0 * sh2))
sh0 = _val.shape[0]
sh1 = _val.shape[1]
sh2 = _val.shape[2]
val = _val.reshape((sh0 * sh1, sh2))
new_out = tensor.dot(out_seq, val)
new_out = tensor.unbroadcast(
new_out.dimshuffle('x', 0, 1), 0)
else:
_out_seq = op.outer_seqs(node)[seqs.index(inp2)]
out_seq = _out_seq.reshape(
(_out_seq.shape[0] * _out_seq.shape[1],
_out_seq.shape[2]))
class PushOutDot1(gof.Optimizer):
"""Graph optimizer for Scan(makes it run inplace)"""
def __init__(self):
gof.Optimizer.__init__(self)
val = _val.dimshuffle(1, 0, 2).reshape(
(_val.shape[1],
_val.shape[0] * _val.shape[2]))
new_out = tensor.dot(val, out_seq)
new_out = tensor.unbroadcast(
new_out.dimshuffle('x', 0, 1), 0)
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def apply(self, fgraph):
nodes = fgraph.toposort()
scan_nodes = [x for x in nodes if (isinstance(x.op, scan_op.Scan))]
for node in scan_nodes:
self.apply_opt(fgraph, node)
def apply_opt(self, fgraph, node):
# Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value
# When seq[t] is a vector/matrix and `value` is a matrix
# Note that this works only when you need X[-1] at the end
# and assumes dimshuffles are applied to vectors before calling dot
op = node.op
sitsot_ins = op.inner_sitsot(op.inputs)
sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_sitsot = op.outer_sitsot_outs(node)
seqs = op.inner_seqs(op.inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
if (out.owner and
isinstance(out.owner.op, theano.tensor.Elemwise) and
isinstance(out.owner.op.scalar_op, theano.scalar.Add) and
inp in out.owner.inputs and
len(outer_out.clients) == 1 and
not isinstance(outer_out.clients[0][0], str) and
isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor)
and outer_out.clients[0][0].op.idx_list == (-1,)):
x = out.owner.inputs[0]
if x == inp:
x = out.owner.inputs[1]
# We need to check whether x is the result of a dot product of two matrices
if (x.owner and
isinstance(x.owner.op, theano.tensor.Dot) and
x.owner.inputs[0].ndim == 2 and
x.owner.inputs[1].ndim == 2):
# We need to check if any of the inputs are a sequence
inp1 = x.owner.inputs[0]
inp2 = x.owner.inputs[1]
if inp1 in seqs or inp2 in seqs:
new_scan_out = inp2
if inp2 in seqs:
new_scan_out = inp1
idx = sitsot_outs.index(out)
# We've found our pattern and need to construct a new
# scan node to replace this one. For this we need to
# replace the sit_sot output with a nit_sot output
# First let us split all arguments according to their
# corresponding categories
inner_seqs = op.inner_seqs(op.inputs)
outer_seqs = op.outer_seqs(node)
inner_mitmot = op.inner_mitmot(op.inputs)
outer_mitmot = op.outer_mitmot(node)
inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
inner_mitsot = op.inner_mitsot(op.inputs)
outer_mitsot = op.outer_mitsot(node)
inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
inner_sitsot = op.inner_sitsot(op.inputs)
outer_sitsot = op.outer_sitsot(node)
inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_nitsot = op.outer_nitsot(node)
inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
inner_shared = op.inner_shared(op.inputs)
outer_shared = op.outer_shared(node)
inner_shared_outs = op.inner_shared_outs(op.outputs)
inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node)
new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info['tap_array'] = (new_info['tap_array'][:st + idx] +
new_info['tap_array'][st + idx + 1:])
new_info['n_sit_sot'] -= 1
new_info['n_nit_sot'] += 1
inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1:]
outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1:]
inner_sitsot_outs = inner_sitsot_outs[:idx] +\
inner_sitsot_outs[idx + 1:]
# add n_steps as the length
inner_nitsot_outs.append(new_scan_out)
_new_inner_inps = (inner_seqs +
inner_mitmot +
inner_mitsot +
inner_sitsot +
inner_shared +
inner_non_seqs)
_new_inner_outs = (inner_mitmot_outs +
inner_mitsot_outs +
inner_sitsot_outs +
inner_nitsot_outs +
inner_shared_outs)
new_inner_inps, new_inner_outs =\
scan_utils.reconstruct_graph(
_new_inner_inps, _new_inner_outs)
new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
new_info)
_scan_inputs = ([node.inputs[0]] +
outer_seqs +
outer_mitmot +
outer_mitsot +
outer_sitsot +
outer_shared +
outer_nitsot +
[node.inputs[0]] +
outer_non_seqs)
new_outs = new_op(*_scan_inputs)
# We now need to pair the new outputs correctly with
# the old ones
outer_mitmot_outs = new_op.outer_mitmot_outs(new_outs)
outer_mitsot_outs = new_op.outer_mitsot_outs(new_outs)
outer_sitsot_outs = new_op.outer_sitsot_outs(new_outs)
outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
outer_shared_outs = new_op.outer_shared_outs(new_outs)
_val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1]
if inp1 in seqs:
_out_seq = op.outer_seqs(node)[seqs.index(inp1)]
# We need to clip the seq to the number of steps
_out_seq = _out_seq[:node.inputs[0]]
sh0 = _out_seq.shape[0]
sh1 = _out_seq.shape[1]
sh2 = _out_seq.shape[2]
out_seq = _out_seq.dimshuffle(1, 0, 2)
out_seq = out_seq.reshape((sh1, sh0 * sh2))
sh0 = _val.shape[0]
sh1 = _val.shape[1]
sh2 = _val.shape[2]
val = _val.reshape((sh0 * sh1, sh2))
new_out = tensor.dot(out_seq, val)
else:
_out_seq = op.outer_seqs(node)[seqs.index(inp2)]
out_seq = _out_seq.reshape(
(_out_seq.shape[0] * _out_seq.shape[1],
_out_seq.shape[2]))
outer_sitsot_outs = (outer_sitsot_outs[:idx] +
[new_out] +
outer_sitsot_outs[idx:])
final_outs = (outer_mitmot_outs +
outer_mitsot_outs +
outer_sitsot_outs +
outer_nitsot_outs +
outer_shared_outs)
val = _val.dimshuffle(1, 0, 2).reshape(
(_val.shape[1],
_val.shape[0] * _val.shape[2]))
new_out = tensor.dot(val, out_seq)
return final_outs
pos = node.outputs.index(outer_out)
old_new = zip(node.outputs[:pos], new_outs[:pos])
old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out))
old_new += zip(node.outputs[pos+1:], new_outs[pos:])
fgraph.replace_all_validate_remove(old_new,
remove = [node],
reason='PushOutDot1')
return False
# I've added an equilibrium because later scan optimization in the sequence
@@ -1625,7 +1632,7 @@ scan_seqopt1.register('scanOp_pushout_seqs_ops',
scan_seqopt1.register('scan_pushout_dot1',
opt.in2out(scan_pushout_dot1, ignore_newtrees=True),
PushOutDot1(),
4,
'fast_run',
'more_mem',
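
Note on the rewrite itself: the dimshuffle/reshape pair followed by a single tensor.dot in apply_opt above relies on the identity that a sum of per-step matrix products equals one large matrix product on reshaped operands. A standalone numpy check of that identity (not part of the commit; shapes are illustrative):

    import numpy as np

    n_steps, d1, d2, d3 = 5, 3, 4, 2
    seq = np.random.randn(n_steps, d1, d2)   # plays the role of the outer sequence
    val = np.random.randn(n_steps, d2, d3)   # plays the role of the collected nit_sot output

    # What the loop accumulates: sum_t dot(seq[t], val[t])
    looped = sum(np.dot(seq[t], val[t]) for t in range(n_steps))

    # What the pushed-out graph computes: one dot on reshaped operands,
    # mirroring _out_seq.dimshuffle(1, 0, 2).reshape(...) and _val.reshape(...)
    pushed = np.dot(seq.transpose(1, 0, 2).reshape(d1, n_steps * d2),
                    val.reshape(n_steps * d2, d3))

    assert np.allclose(looped, pushed)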