提交 144e351a — 作者: Arnaud Bergeron

Do some clean up and add some asserts to make sure we don't miss anything.

上级 5ed8d923
......@@ -217,10 +217,9 @@ class Scan(PureOp):
if rval.ndim == as_var.ndim:
rval = as_var.type.filter_variable(rval)
else:
tmp = as_var.type.__class__(
broadcastable=tuple(var.broadcastable[:1])+\
tuple(as_var.broadcastable),
dtype=as_var.dtype)
tmp = as_var.type.clone(
broadcastable=(tuple(var.broadcastable[:1]) +
tuple(as_var.broadcastable)))
rval = tmp.filter_variable(rval)
return rval
......@@ -493,11 +492,11 @@ class Scan(PureOp):
return aux_txt
def __hash__(self):
return (hash(type(self)) ^
return hash((type(self),
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self._hash_inner_graph ^
scan_utils.hash_listsDictsTuples(self.info))
self._hash_inner_graph,
scan_utils.hash_listsDictsTuples(self.info)))
def make_thunk(self, node, storage_map, compute_map, no_recycling):
"""
......
......@@ -46,6 +46,7 @@ def safe_new(x, tag='', dtype=None):
nw_name = x.name + tag
else:
nw_name = None
if isinstance(x, theano.Constant):
if dtype and x.dtype != dtype:
casted_x = x.astype(dtype)
......@@ -54,20 +55,14 @@ def safe_new(x, tag='', dtype=None):
return nwx
else:
return x.clone()
# Note, as_tensor_variable will convert the Scalar into a
# TensorScalar that will require a ScalarFromTensor op,
# making the pushout optimization fail
elif isinstance(x, scalar.ScalarVariable):
if dtype:
nw_x = scalar.get_scalar_type(dtype=dtype)()
else:
nw_x = x.type()
nw_x.name = nw_name
return nw_x
# at this point we should only have Variables
assert isinstance(x, theano.Variable)
nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype).type()
nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may
......@@ -807,7 +802,7 @@ class scan_args(object):
def __init__(self, outer_inputs, outer_outputs,
_inner_inputs, _inner_outputs, info):
self.n_steps = outer_inputs[0]
rval = reconstruct_graph(_inner_inputs, _inner_outputs, '_merge')
rval = reconstruct_graph(_inner_inputs, _inner_outputs, '')
if info['as_while']:
self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1]
......@@ -911,6 +906,9 @@ class scan_args(object):
p += n_shared_outs
q += n_shared_outs
assert p == len(outer_outputs)
assert q == len(inner_outputs)
self.other_info = OrderedDict()
for k in ('truncate_gradient', 'name', 'mode', 'destroy_map',
'gpu', 'gpua', 'as_while', 'profile', 'allow_gc'):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论