提交 2a03db63 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Transform the optimization into a global one to deal with some issues

上级 63353634
...@@ -1410,10 +1410,23 @@ def scan_merge_inouts(node): ...@@ -1410,10 +1410,23 @@ def scan_merge_inouts(node):
return na.outer_outputs return na.outer_outputs
@gof.local_optimizer([None]) class PushOutDot1(gof.Optimizer):
def scan_pushout_dot1(node): """Graph optimizer for Scan(makes it run inplace)"""
if not isinstance(node.op, scan_op.Scan): def __init__(self):
return False Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def apply(self, fgraph):
nodes = fgraph.toposort()
scan_nodes = [x for x in nodes if (isinstance(x.op, scan_op.Scan))]
for node in scan_nodes:
self.apply_opt(fgraph, node)
def apply_opt(self, fgraph, node):
# Replace pattern of the form # Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value) # x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value # with Sequence.reshape((-1, seq.shape[2])) \dot Value
...@@ -1520,6 +1533,7 @@ def scan_pushout_dot1(node): ...@@ -1520,6 +1533,7 @@ def scan_pushout_dot1(node):
outer_nitsot + outer_nitsot +
[node.inputs[0]] + [node.inputs[0]] +
outer_non_seqs) outer_non_seqs)
new_outs = new_op(*_scan_inputs) new_outs = new_op(*_scan_inputs)
# We need now to pair correctly the new outputs with the # We need now to pair correctly the new outputs with the
...@@ -1547,8 +1561,6 @@ def scan_pushout_dot1(node): ...@@ -1547,8 +1561,6 @@ def scan_pushout_dot1(node):
val = _val.reshape((sh0 * sh1, sh2)) val = _val.reshape((sh0 * sh1, sh2))
new_out = tensor.dot(out_seq, val) new_out = tensor.dot(out_seq, val)
new_out = tensor.unbroadcast(
new_out.dimshuffle('x', 0, 1), 0)
else: else:
_out_seq = op.outer_seqs(node)[seqs.index(inp2)] _out_seq = op.outer_seqs(node)[seqs.index(inp2)]
out_seq = _out_seq.reshape( out_seq = _out_seq.reshape(
...@@ -1559,21 +1571,16 @@ def scan_pushout_dot1(node): ...@@ -1559,21 +1571,16 @@ def scan_pushout_dot1(node):
(_val.shape[1], (_val.shape[1],
_val.shape[0] * _val.shape[2])) _val.shape[0] * _val.shape[2]))
new_out = tensor.dot(val, out_seq) new_out = tensor.dot(val, out_seq)
new_out = tensor.unbroadcast(
new_out.dimshuffle('x', 0, 1), 0)
outer_sitsot_outs = (outer_sitsot_outs[:idx] +
[new_out] +
outer_sitsot_outs[idx:])
final_outs = (outer_mitmot_outs +
outer_mitsot_outs +
outer_sitsot_outs +
outer_nitsot_outs +
outer_shared_outs)
return final_outs pos = node.outputs.index(outer_out)
old_new = zip(node.outputs[:pos], new_outs[:pos])
old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out))
old_new += zip(node.outputs[pos+1:], new_outs[pos:])
fgraph.replace_all_validate_remove(old_new,
remove = [node],
reason='PushOutDot1')
return False
# I've added an equilibrium because later scan optimization in the sequence # I've added an equilibrium because later scan optimization in the sequence
...@@ -1625,7 +1632,7 @@ scan_seqopt1.register('scanOp_pushout_seqs_ops', ...@@ -1625,7 +1632,7 @@ scan_seqopt1.register('scanOp_pushout_seqs_ops',
scan_seqopt1.register('scan_pushout_dot1', scan_seqopt1.register('scan_pushout_dot1',
opt.in2out(scan_pushout_dot1, ignore_newtrees=True), PushOutDot1(),
4, 4,
'fast_run', 'fast_run',
'more_mem', 'more_mem',
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论