提交 c8fc0905 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

made the code to use generic input/output list

Before I used to have the different types of inputs/outputs seperated into lists inside the scan op. I've made a more generic approach, where there is a simple list of all inputs/outputs for scan.
上级 4c1da020
......@@ -42,40 +42,30 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
class ScanOp(PureOp):
def __init__(self,
inputs,
input_states,
parameters,
non_numeric_input_states,
non_numeric_output_states,
output_states,
outputs,
lengths,
mintaps,
switches,
options,
as_repeatUntil):
self.options = options # name/mode/inplace/gpu/profile
self.inputs = inputs
self.input_states = input_states
self.parameters = parameters
self.non_numeric_input_states = non_numeric_input_states
self.non_numeric_output_states = non_numeric_output_states
self.output_states = output_states
self.outputs = outputs
self.switches = switches
self.lengths = lengths
self.mintaps = mintaps
self.name = name
self.mode = mode
self.inplace = inplace
self.gpu = gpu
self.as_repeatUntil = as_repeatUntil
self.profile = profile
self.options = options
self.name = options['name']
self.mode = options['mode']
self.inplace = options['inplace']
self.gpu = options['gpu']
self.profile = options['profile']
self.hash_inner_graph = options['hash_inner_graph']
# --Construct the destroy map--
if self.inplace:
n_outs = (len(output_states) + len(outputs) +
len(non_numeric_output_states))
for idx in xrange(n_outs):
self.destroy_map[idx] = [idx + 1 + len(inputs)]
for idx in xrange(len(outputs)):
self.destroy_map[idx] = [idx + 1]
# --Decide on the default mode--
mode_instance = compile.mode.get_mode(self.mode)
# if the default mode is used, and that mode is ProfileMode
# then we need to copy the mode otherwise the time for a given
......@@ -99,14 +89,9 @@ class ScanOp(PureOp):
self.name = 'scan_fn'
def make_node(self, *inputs):
out_types = []
out_types.extend(
[out_state.type() for out_state in self.output_states])
out_types.extend(
[out.type() for out in self.outputs])
out_types.extend(
[non_numeric_out_state.type() for non_numeric_out_state in
self.non_numeric_output_states])
# Checking if arguments are of the right type is done in the scan
# function
out_types = [out.type() for out in self.outputs]
return Apply(self, inputs, out_types)
def __eq__(self, other):
......@@ -119,24 +104,27 @@ class ScanOp(PureOp):
return False
elif not len(self.outputs) == len(other.outputs):
return False
elif self.info != other.info:
if self.mintals != other.mintaps:
return False
else:
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
for self_in, other_in in izip(self.inputs, other.inputs):
if self_in.type != other_in.type:
# Check if the number of different types of arguments is the same
diff_args = ['inputs', 'outputs', 'lengths', 'mintaps', 'switches']
for arg in diff_args:
if len(getattr(self, arg)) != len(getattr(other, arg)):
return False
if not scan_utils.equal_computations(self.outputs,
other.outputs,
self.inputs,
other.inputs):
for x, y in izip(self.inputs, other.inputs):
if x.type != y.type:
return False
for x, y in izip(self.lengths, other.lengths):
if x.type != y.type:
return False
# If they do, then they need to match in other small details
# like name, mode, etc.
s_inputs = [self.t] + self.inputs + self.lengths + self.switches
o_inputs = [other.t] + other.inputs + other.lengths + other.switches
givens = dict(izip(s_inputs, o_inputs))
# This part might be slow
for x, y in izip(self.outputs, other.outputs):
if not gof.graph.is_same_graph(x, y, givens=givens):
return False
return True
def __str__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论