提交 6f5a06d8 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add case to push out a sum-of-dot from Scan

上级 bdcb0c0d
...@@ -705,6 +705,81 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -705,6 +705,81 @@ class PushOutScanOutput(gof.Optimizer):
break break
elif (isinstance(nd.op, theano.tensor.elemwise.Elemwise) and
isinstance(nd.op.nfunc, numpy.ufunc) and
nd.op.nfunc.__name__ == 'add' and
nd.out in args.inner_out_sit_sot and
self.inner_sitsot_only_last_step_used(nd.out, args)):
# Ensure that one of the input to the add is the output of
# the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
dot_in_idx = 1 - sitsot_in_idx # 0 if sitsot_in_idx==1,
# 1 if sitsot_in_idx==0
dot_input = nd.inputs[dot_in_idx]
if (isinstance(dot_input.owner.op, theano.tensor.Dot) and
len(dot_input.clients) == 1 and
dot_input.owner.inputs[0].ndim == 2 and
dot_input.owner.inputs[1].ndim == 2 and
self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and
self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3):
# The optimization can be be applied in this case.
# Move out of scan the two inputs to the Dot and
# perform a dot outside of scan on these two inputs
inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs
(outer_dot_inputs,
new_scan_node,
new_scan_args) = self.push_out_inner_vars(fgraph,
inner_dot_inputs,
node, args)
# Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs[0] = theano.tensor.flatten(
outer_dot_inputs[0].dimshuffle(1,0,2),
outdim=2)
shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
outer_dot_inputs[1] = outer_dot_inputs[1].reshape((shape_input1[0] *
shape_input1[1],
shape_input1[2]))
# Perform the dot on the newly obtained matrices and
# add the initial value
outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = outer_sitsot.clients[0][0]
outer_sitsot_last_step = subtensor_node.outputs[0]
fgraph.replace_all([
(outer_sitsot_last_step, replacement)],
reason="scanOp_pushout_output")
break
return new_scan_node return new_scan_node
def inner_sitsot_only_last_step_used(self, var, scan_args): def inner_sitsot_only_last_step_used(self, var, scan_args):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论