提交 4258ebb1 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Support case where vector input is already an output of scan

上级 56d71501
......@@ -612,6 +612,11 @@ class PushOutScanOutput(gof.Optimizer):
op = node.op
# Use scan_args to parse the inputs and outputs of scan for ease of
# use
args = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info)
# Obtain the list containing the indices, in clean_outputs, of the
# scan op's outputs that are nit_sot (not fed back to the inner fct.)
nitsot_outs = op.inner_nitsot_outs(node.outputs)
......@@ -678,8 +683,7 @@ class PushOutScanOutput(gof.Optimizer):
(nd.inputs[0] in inner_non_seqs or
isinstance(nd.inputs[0], tensor.Constant)) and
nd.inputs[1].ndim == 1 and
nd.inputs[1] not in clean_inputs and
nd.inputs[1] not in clean_outputs):
nd.inputs[1] not in clean_inputs):
valid_inputs = True
idx_matrix_input = 0
......@@ -689,8 +693,7 @@ class PushOutScanOutput(gof.Optimizer):
(nd.inputs[1] in inner_non_seqs or
isinstance(nd.inputs[1], tensor.Constant)) and
nd.inputs[0].ndim == 1 and
nd.inputs[0] not in clean_inputs and
nd.inputs[0] not in clean_outputs):
nd.inputs[0] not in clean_inputs):
valid_inputs = True
idx_matrix_input = 1
......@@ -717,34 +720,44 @@ class PushOutScanOutput(gof.Optimizer):
'this on theano-users list'),
inner_matrix_input)
# Add the new outputs to the scan (get as output the variables of
# the outer graph corresponding to the new scan outputs
new_output_inner = nd.inputs[idx_vector_input]
_new_scan_node, idx_old_outputs, idx_new_output = self.add_nitsot_outputs(
fgraph, node,
clean_inputs,
clean_outputs,
new_output_inner)
new_outer_output = _new_scan_node.outputs[idx_new_output]
# Perform the Dot on the new scan output.
# If the vector_input is already a nit_sot output of the
# scan, get a reference to the corresponding outer output.
# Otherwise, add it as a new nit_sot output and then get a
# reference to it
if nd.inputs[idx_vector_input] in nitsot_outs:
# Figure out which scan output corresponds the vector
# input
inner_vector_input = nd.inputs[idx_vector_input]
vector_input_nitsot_idx = args.inner_out_nit_sot.index(inner_vector_input)
outer_vector_input = args.outer_out_nit_sot[vector_input_nitsot_idx]
else:
# Add the vector_input as a new nitsot output to scan
new_output_inner = nd.inputs[idx_vector_input]
new_scan_node, idx_old_outputs, idx_new_output = self.add_nitsot_outputs(
fgraph, node,
clean_inputs,
clean_outputs,
new_output_inner)
outer_vector_input = new_scan_node.outputs[idx_new_output]
# Perform the Dot outside of scan
if idx_matrix_input == 0:
outer_dot_inputs = [outer_matrix_input,
new_outer_output.transpose()]
outer_vector_input.transpose()]
outer_dot_output = theano.tensor.dot(*outer_dot_inputs).transpose()
else: # idx_matrix_input == 1
outer_dot_inputs = [new_outer_output,
outer_dot_inputs = [outer_vector_input,
outer_matrix_input]
outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
# Modify the outer graph to add the outer Dot
new_idx_dot_output = idx_old_outputs[idx_dot_output]
fgraph.replace_all([
(_new_scan_node.outputs[new_idx_dot_output],
(new_scan_node.outputs[new_idx_dot_output],
outer_dot_output)],
reason="scanOp_pushout_output")
new_scan_node = _new_scan_node
break
return new_scan_node
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论