提交 2a03db63 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Transform the optimization into a global one to deal with some issues

上级 63353634
......@@ -1410,10 +1410,23 @@ def scan_merge_inouts(node):
return na.outer_outputs
@gof.local_optimizer([None])
def scan_pushout_dot1(node):
if not isinstance(node.op, scan_op.Scan):
return False
class PushOutDot1(gof.Optimizer):
"""Graph optimizer for Scan(makes it run inplace)"""
def __init__(self):
    """Initialize the optimizer by delegating to the base class.

    The base initializer is referenced as ``gof.Optimizer`` -- the same
    spelling the class statement uses -- so construction does not depend
    on a bare ``Optimizer`` name also being imported into this module.
    """
    gof.Optimizer.__init__(self)
def add_requirements(self, fgraph):
    """Attach the graph features this optimizer needs to ``fgraph``.

    A ``toolbox.ReplaceValidate`` instance is attached first, followed by
    a ``DestroyHandler`` instance.
    """
    # NOTE(review): ReplaceValidate presumably provides the
    # ``replace_all_validate_remove`` call used when rewriting -- confirm.
    for make_feature in (toolbox.ReplaceValidate, DestroyHandler):
        # Each feature is built right before it is attached, one at a time.
        fgraph.attach_feature(make_feature())
def apply(self, fgraph):
    """Run ``apply_opt`` on every Scan node found in ``fgraph``.

    The Scan nodes are collected from ``fgraph.toposort()`` up front, and
    only then handed one by one to ``self.apply_opt``.
    """
    # Collect first, rewrite afterwards: the list of candidates is fixed
    # before any call to apply_opt is made.
    scan_nodes = []
    for candidate in fgraph.toposort():
        if isinstance(candidate.op, scan_op.Scan):
            scan_nodes.append(candidate)

    for scan_node in scan_nodes:
        self.apply_opt(fgraph, scan_node)
def apply_opt(self, fgraph, node):
# Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value
......@@ -1520,6 +1533,7 @@ def scan_pushout_dot1(node):
outer_nitsot +
[node.inputs[0]] +
outer_non_seqs)
new_outs = new_op(*_scan_inputs)
# We need now to pair correctly the new outputs with the
......@@ -1547,8 +1561,6 @@ def scan_pushout_dot1(node):
val = _val.reshape((sh0 * sh1, sh2))
new_out = tensor.dot(out_seq, val)
new_out = tensor.unbroadcast(
new_out.dimshuffle('x', 0, 1), 0)
else:
_out_seq = op.outer_seqs(node)[seqs.index(inp2)]
out_seq = _out_seq.reshape(
......@@ -1559,21 +1571,16 @@ def scan_pushout_dot1(node):
(_val.shape[1],
_val.shape[0] * _val.shape[2]))
new_out = tensor.dot(val, out_seq)
new_out = tensor.unbroadcast(
new_out.dimshuffle('x', 0, 1), 0)
outer_sitsot_outs = (outer_sitsot_outs[:idx] +
[new_out] +
outer_sitsot_outs[idx:])
final_outs = (outer_mitmot_outs +
outer_mitsot_outs +
outer_sitsot_outs +
outer_nitsot_outs +
outer_shared_outs)
return final_outs
pos = node.outputs.index(outer_out)
old_new = zip(node.outputs[:pos], new_outs[:pos])
old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out))
old_new += zip(node.outputs[pos+1:], new_outs[pos:])
fgraph.replace_all_validate_remove(old_new,
remove = [node],
reason='PushOutDot1')
return False
# I've added an equilibrium because later scan optimization in the sequence
......@@ -1625,7 +1632,7 @@ scan_seqopt1.register('scanOp_pushout_seqs_ops',
scan_seqopt1.register('scan_pushout_dot1',
opt.in2out(scan_pushout_dot1, ignore_newtrees=True),
PushOutDot1(),
4,
'fast_run',
'more_mem',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论