提交 3b1e6048 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

added some meat to the scan op skeleton

上级 07f12948
......@@ -38,35 +38,130 @@ from scan_utils import safe_new
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_op')
class ScanOp(PureOp):
def __init__(self,
inputs,
input_states,
parameters,
non_numeric_states,
non_numeric_input_states,
non_numeric_output_states,
output_states,
outputs,
lengths,
mintaps,
name,
mode,
inplace,
gpu,
as_repeatUntil,
profile):
pass
options,
as_repeatUntil):
self.options = options # name/mode/inplace/gpu/profile
self.inputs = inputs
self.input_states = input_states
self.parameters = parameters
self.non_numeric_input_states = non_numeric_input_states
self.non_numeric_output_states = non_numeric_output_states
self.output_states = output_states
self.outputs = outputs
self.lengths = lengths
self.mintaps = mintaps
self.name = name
self.mode = mode
self.inplace = inplace
self.gpu = gpu
self.as_repeatUntil = as_repeatUntil
self.profile = profile
if self.inplace:
n_outs = (len(output_states) + len(outputs) +
len(non_numeric_output_states))
for idx in xrange(n_outs):
self.destroy_map[idx] = [idx + 1 + len(inputs)]
mode_instance = compile.mode.get_mode(self.mode)
# if the default mode is used, and that mode is ProfileMode
# then we need to copy the mode otherwise the time for a given
# op will be counted multiple times
if (self.mode is None and
isinstance(mode_instance, compile.profilemode.ProfileMode)):
mode_instance = compile.profilemode.ProfileMode(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker)
compile.profilemode.prof_mode_instance_to_print.append(
mode_instance)
self.mode_instance = mode_instance
if self.name:
self.mode_instance.message = self.name + " sub profile"
else:
self.mode_instance.message = "Scan sub profile"
else:
self.mode_instance = mode_instance
if not hasattr(self, 'name') or self.name is None:
self.name = 'scan_fn'
def make_node(self, *inputs):
pass
out_types = []
out_types.extend(
[out_state.type() for out_state in self.output_states])
out_types.extend(
[out.type() for out in self.outputs])
out_types.extend(
[non_numeric_out_state.type() for non_numeric_out_state in
self.non_numeric_output_states])
return Apply(self, inputs, out_types)
def __eq__(self, other):
pass
# Check if we are dealing with same type of objects
if not type(self) == type(other):
return False
# This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs )
elif not len(self.inputs) == len(other.inputs):
return False
elif not len(self.outputs) == len(other.outputs):
return False
elif self.info != other.info:
return False
else:
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
for self_in, other_in in izip(self.inputs, other.inputs):
if self_in.type != other_in.type:
return False
if not scan_utils.equal_computations(self.outputs,
other.outputs,
self.inputs,
other.inputs):
return False
# If they do, then they need to match in other small details
# like name, mode, etc.
return True
def __str__(self):
pass
if self.gpu:
gpu_str = 'gpu'
else:
gpu_str = 'cpu'
if self.as_repeatUntil:
name = 'do_while'
else:
name = 'for'
if self.inplace:
aux_txt = '%s{inplace,%s,%s}' % (name, gpu_str, str(self.name))
else:
aux_txt = '%s{%s,%s}' % (name, gpu_str, str(self.name))
return aux_txt
def __hash__(self):
pass
return (hash(type(self)) ^
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self._hash_inner_graph ^
scan_utils.hash_listsDictsTuples(self.info))
def make_thunk(self, node, storage_map, compute_map, no_recycling):
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论