提交 595ec4b2 authored 作者: lamblin's avatar lamblin

Merge pull request #1009 from pascanur/scan_grad_dtype_issue

Scan grad dtype issue
...@@ -179,3 +179,7 @@ class DeepCopyOp(gof.Op): ...@@ -179,3 +179,7 @@ class DeepCopyOp(gof.Op):
deep_copy_op = DeepCopyOp() deep_copy_op = DeepCopyOp()
# List of Theano Types that one can add an extra dimension and for which
# Scan can deal with.
expandable_types = ()
...@@ -411,6 +411,7 @@ class CudaNdarrayType(Type): ...@@ -411,6 +411,7 @@ class CudaNdarrayType(Type):
def c_compile_args(self): def c_compile_args(self):
return [] return []
theano.compile.ops.expandable_types += (CudaNdarrayType,)
# Register C code for ViewOp on CudaNdarrayType # Register C code for ViewOp on CudaNdarrayType
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
......
...@@ -53,6 +53,7 @@ from theano.tensor import opt ...@@ -53,6 +53,7 @@ from theano.tensor import opt
from theano import tensor from theano import tensor
from theano import config from theano import config
from theano.updates import Updates from theano.updates import Updates
from theano.compile import ops
import scan_op import scan_op
...@@ -843,17 +844,38 @@ def scan(fn, ...@@ -843,17 +844,38 @@ def scan(fn,
shared_scan_inputs = [] shared_scan_inputs = []
shared_inner_inputs = [] shared_inner_inputs = []
shared_inner_outputs = [] shared_inner_outputs = []
sit_sot_shared = []
for input in dummy_f.maker.expanded_inputs: for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update: if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
if getattr(input.variable, 'name', None) is not None: if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy' new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append(new_var) if isinstance(new_var.type, ops.expandable_types):
shared_scan_inputs.append(input.variable) sit_sot_inner_inputs.append(new_var)
shared_inner_outputs.append(input.update) sit_sot_scan_inputs.append(
givens[input.variable] = new_var scan_utils.expand(
n_shared_outs += 1 tensor.unbroadcast(
tensor.shape_padleft(input.variable), 0),
actual_n_steps))
sit_sot_inner_outputs.append(input.update)
# Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan.
# If `pos` is positive than it corresponds to the standard
# outputs of scan and it refers to output of index `pos`. If `pos`
# is negative that it corresponds to update rules of scan and it
# refers to update rule of index -1 - `pos`.
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
givens[input.variable] = new_var
else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
n_shared_outs += 1
n_sit_sot = len(sit_sot_inner_inputs)
## Step 5.4 Outputs with no taps used in the input ## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0 n_nit_sot = 0
nit_sot_inner_outputs = [] nit_sot_inner_outputs = []
...@@ -1041,10 +1063,20 @@ def scan(fn, ...@@ -1041,10 +1063,20 @@ def scan(fn,
nit_sot_rightOrder) nit_sot_rightOrder)
scan_out_list = [None] * len(rightOrder) scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder): for idx, pos in enumerate(rightOrder):
scan_out_list[pos] = _scan_out_list[idx] if pos >= 0:
scan_out_list[pos] = _scan_out_list[idx]
else:
# Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan.
# If `pos` is positive than it corresponds to the standard
# outputs of scan and it refers to output of index `pos`. If `pos`
# is negative that it corresponds to update rules of scan and it
# refers to update rule of index -1 - `pos`.
update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
scan_out_list = [x for x in scan_out_list if x is not None]
if len(scan_out_list) == 1: if len(scan_out_list) == 1:
scan_out_list = scan_out_list[0] scan_out_list = scan_out_list[0]
elif len(scan_out_list) == 0: elif len(scan_out_list) == 0:
scan_out_list = None scan_out_list = None
return (scan_out_list, update_map) return (scan_out_list, update_map)
...@@ -34,7 +34,7 @@ from theano.gradient import DisconnectedType ...@@ -34,7 +34,7 @@ from theano.gradient import DisconnectedType
from theano.compile.profiling import ScanProfileStats from theano.compile.profiling import ScanProfileStats
import scan_utils import scan_utils
from scan_utils import safe_new from scan_utils import safe_new, forced_replace
# Logging function for sending warning or info # Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_op') _logger = logging.getLogger('theano.scan_module.scan_op')
...@@ -259,7 +259,7 @@ class Scan(PureOp): ...@@ -259,7 +259,7 @@ class Scan(PureOp):
for idx, (inner_seq, outer_seq) in enumerate( for idx, (inner_seq, outer_seq) in enumerate(
zip(self.inner_seqs(self.inputs), zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs))): self.outer_seqs(inputs))):
if inner_seq.type.dtype != outer_seq[idx].type.dtype: if inner_seq.type.dtype != outer_seq[0].type.dtype:
assert isinstance(idx, int) assert isinstance(idx, int)
raise ValueError(err_msg1 % ('sequence', raise ValueError(err_msg1 % ('sequence',
...@@ -292,8 +292,11 @@ class Scan(PureOp): ...@@ -292,8 +292,11 @@ class Scan(PureOp):
str(outer_mitmot), str(outer_mitmot),
argoffset + idx, argoffset + idx,
outer_mitmot.type.dtype, outer_mitmot.type.dtype,
outer_mitmot.type.ndim,
str(inner_mitmot[ipos + k]), str(inner_mitmot[ipos + k]),
inner_mitmot[ipos + k].type.dtype)) inner_mitmot[ipos +
k].type.dtype,
inner_mitmot[ipos + k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
for k in xrange(len(otaps)): for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos + k].type.dtype != \ if (inner_mitmot_outs[opos + k].type.dtype != \
...@@ -304,7 +307,9 @@ class Scan(PureOp): ...@@ -304,7 +307,9 @@ class Scan(PureOp):
(str(outer_mitmot), (str(outer_mitmot),
argoffset + idx, argoffset + idx,
outer_mitmot.type.dtype, outer_mitmot.type.dtype,
inner_mitmot_outs[opos + k].type.dtype)) outer_mitmot.ndim,
inner_mitmot_outs[opos + k].type.dtype,
inner_mitmot_outs[opos + k].ndim))
opos += len(otaps) opos += len(otaps)
argoffset += len(self.outer_mitmot(inputs)) argoffset += len(self.outer_mitmot(inputs))
# Same checks as above but for outputs of type mit_sot # Same checks as above but for outputs of type mit_sot
...@@ -329,14 +334,14 @@ class Scan(PureOp): ...@@ -329,14 +334,14 @@ class Scan(PureOp):
inner_mitsots[ipos + k].type.ndim)) inner_mitsots[ipos + k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
inner_mitsot_out.ndim != outer_mitsot.ndim - 1): inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_mitsot), (str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.type.dtype, outer_mitsot.type.dtype,
outer_mitsot.type.ndim, outer_mitsot.type.ndim,
inner_mitsot_out.type.dtype, inner_mitsot_out.type.dtype,
inner_mitsot_out.type.ndim)) inner_mitsot_out.type.ndim))
argoffset += len(self.outer_mitsot(inputs)) argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot # Same checks as above but for outputs of type sit_sot
...@@ -348,22 +353,22 @@ class Scan(PureOp): ...@@ -348,22 +353,22 @@ class Scan(PureOp):
inner_sitsot.ndim != outer_sitsot.ndim - 1): inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_sitsot), str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
outer_sitsot.type.ndim, outer_sitsot.type.ndim,
str(inner_sitsot), str(inner_sitsot),
inner_sitsot.type.dtype, inner_sitsot.type.dtype,
inner_sitsot.type.ndim)) inner_sitsot.type.ndim))
if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
inner_sitsot_out.ndim != outer_sitsot.ndim - 1): inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_sitsot), (str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
outer_sitsot.type.ndim, outer_sitsot.type.ndim,
inner_sitsot_out.type.dtype, inner_sitsot_out.type.dtype,
inner_sitsot_out.type.ndim)) inner_sitsot_out.type.ndim))
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
...@@ -397,9 +402,7 @@ class Scan(PureOp): ...@@ -397,9 +402,7 @@ class Scan(PureOp):
for inner_nonseq, outer_nonseq in zip( for inner_nonseq, outer_nonseq in zip(
self.inner_non_seqs(self.inputs), self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)): self.outer_non_seqs(inputs)):
if (inner_nonseq.type.dtype != outer_nonseq.type.dtype or if inner_nonseq.type != outer_nonseq.type:
inner_nonseq.type.ndim != outer_nonseq.type.ndim):
raise ValueError(('Argument %s given to scan node does not' raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s') % ' match its correspondance %s') %
(str(outer_nonseq), str(inner_nonseq))) (str(outer_nonseq), str(inner_nonseq)))
...@@ -1194,198 +1197,268 @@ class Scan(PureOp): ...@@ -1194,198 +1197,268 @@ class Scan(PureOp):
for o, x in izip(node.outputs, scan_outs)] for o, x in izip(node.outputs, scan_outs)]
return scan_outs return scan_outs
### GRAD FUNCTION def get_input_pos(self, output_index):
def grad(self, args, g_outs): ipos = self.n_seqs
opos = output_index
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
if len(otaps) > opos:
return ipos
else:
opos = opos - len(otaps)
ipos += len(itaps)
for dx, taps in enumerate(self.mitsot_taps()):
if opos == 0:
return ipos
else:
opos = opos - 1
ipos += len(taps)
if opos < self.info['n_sit_sot']:
return ipos + opos
else:
return -1
def get_output_pos(self, input_index):
ipos = input_index
opos = 0
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
if len(itaps) > ipos:
return opos
else:
opos += len(otaps)
ipos -= len(itaps)
for dx, taps in enumerate(self.mitsot_taps()):
if len(taps) > ipos:
return opos
else:
opos += 1
ipos -= len(taps)
if ipos < self.info['n_sit_sot']:
return ipos + opos
else:
return -1
def get_output_slice_idx(self, output_index):
ipos = 0
opos = output_index
for otaps in zip(self.mitmot_out_taps()):
if len(otaps) > 0:
return ipos
else:
opos = opos - 1
ipos += len(otaps)
return ipos + opos
def connection_pattern(self, node):
# The gradient wrt to n_steps is disconnected
connection_pattern = [[False for output in node.outputs]]
connection_pattern += [[False for output in node.outputs]
for x in node.inputs[1:]]
def compute_gradient(y, g_y, diff_inputs):
rval = []
gmp = {}
consider_inps = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs]
for x in consider_inps:
try:
_gmp = gradient.grad_sources_inputs(
[(y, g_y)],
[x])
gmp[x] = _gmp[x]
except TypeError:
# It means the gradient is undefined (which implies
# is connected)
gmp[x] = x
return [gmp.get(p, None) for p in diff_inputs]
def _get_inner_outs(oidx):
s = 0
if self.n_mit_mot > 0:
e = len(self.mitmot_out_taps()[0])
else:
e = 1
for p in xrange(oidx):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return self.outputs[s:e]
# This discards information about whether incoming gradients are 0 def _get_inner_inps(iidx):
# or disconnected from the cost s = 0
# TODO: upgrade scan op to report disconnection correctly if self.n_seqs > 0:
def strip_disconnected(g): e = 1
if isinstance(g.type, DisconnectedType): else:
e = len(self.tap_array[0])
p = iidx
if node.inputs[iidx + 1] in self.outer_nitsot(node):
return None return None
return g if node.inputs[iidx + 1] in self.outer_non_seqs(node):
loc_idx = self.outer_non_seqs(node).index(
node.inputs[iidx + 1])
return [self.inner_non_seqs(self.inputs)[loc_idx]]
for p in xrange(iidx):
s = e
if p < self.n_seqs:
e += 1
elif p - self.n_seqs < len(self.tap_array):
e += len(self.tap_array[p - self.n_seqs])
else:
e += 1
return self.inputs[s:e]
for oidx, out in enumerate(node.outputs):
for iidx, inp in enumerate(node.inputs[1:]):
ols = _get_inner_outs(oidx)
ils = _get_inner_inps(iidx)
g_outs = [strip_disconnected(g) for g in g_outs] if ils is None:
# The gradient should be disconnected
connection_pattern[iidx + 1][oidx] = False
else:
for inner_out in ols:
# We check for the dtype because inner_out could be
# any Theano type like Generic or RandomState, for
# which we can not impose a dtype
if hasattr(inner_out, 'dtype'):
# Note that we do not care about the output of
# this compute gradient. We just care to see if
# it is None or not. (i.e. disconnected or not)
tmp = compute_gradient(
inner_out,
safe_new(inner_out, dtype='float64'),
ils)
else:
# It should be undefined not disconnected
tmp = ils
if any([x is not None for x in tmp]):
connection_pattern[iidx + 1][oidx] = True
return connection_pattern
# 1. forward pass - get the outputs after applying scan ### GRAD FUNCTION
scan_outputs = self(*args) def grad(self, inputs, dC_douts):
# 2. make sure they are given as a list outs = self(*inputs)
if not(type(scan_outputs) in (list, tuple)): if not isinstance(outs, (list, tuple)):
scan_outputs = [scan_outputs] outs = [outs]
# 3. un-group / unzip the inputs # `grad_step` equals the number of steps the original scan node has
# Note ! We don't want to use the actual same variable as the ones # done (if the original scan is a while loop than this number is the
# used by the original scan, rather create clones of them # length of the output sequence)
# We do not know what kind of outputs the original scan has, so we
# try first to see if it has a nit_sot output, then a sit_sot and
# then a mit_sot
if self.n_nit_sot > 0:
grad_steps = self.outer_nitsot_outs(outs)[0].shape[0]
elif self.n_sit_sot > 0:
grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1
elif self.n_mit_sot > 0:
grad_steps = self.outer_mitsot_outs(outs)[0].shape[0] +\
self.mintaps[self.n_mit_mot]
else:
grad_steps = inputs[0]
rval = scan_utils.reconstruct_graph(self.inputs, rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs, '_grad') self.outputs)
self_inputs = rval[0] self_inputs = rval[0]
self_outputs = rval[1] self_outputs = rval[1]
#differentiable inputs
diff_inputs = (self.inner_seqs(self_inputs) +
self.inner_mitmot(self_inputs) +
self.inner_mitsot(self_inputs) +
self.inner_sitsot(self_inputs) +
self.inner_non_seqs(self_inputs))
diff_outputs = (self.inner_mitmot_outs(self_outputs) +
self.inner_mitsot_outs(self_outputs) +
self.inner_sitsot_outs(self_outputs) +
self.inner_nitsot_outs(self_outputs))
seqs = self_inputs[:self.n_seqs]
offset = self.n_seqs
n_ins_mit_mot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange(self.n_mit_mot)])
outs_mit_mot = self_inputs[offset:offset + n_ins_mit_mot]
offset += n_ins_mit_mot
n_ins_mit_sot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange(self.n_mit_mot,
self.n_mit_mot + self.n_mit_sot)])
outs_mit_sot = self_inputs[offset:offset + n_ins_mit_sot]
offset += n_ins_mit_sot
outs_sit_sot = self_inputs[offset:offset + self.n_sit_sot]
offset += self.n_sit_sot
old_scan_shared_ins = self_inputs[offset:offset + self.n_shared_outs]
out_offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_nit_sot +
self.n_sit_sot)
# shared variables as well as the condition
old_scan_shared_outs = self_outputs[out_offset:]
arg_offset = (1 +
self.n_seqs +
self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot)
old_scan_init = args[arg_offset: arg_offset + self.n_shared_outs]
offset += self.n_shared_outs
other_args = self_inputs[offset:]
# 4. Collect (possibly) differentiable inputs
diff_inputs = (seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
other_args)
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
def compute_gradient(y, g_y): def compute_gradient(y, g_y):
gmp = gradient.grad_sources_inputs( gmp = gradient.grad_sources_inputs(
[(y, g_y)], diff_inputs) [(y, g_y)],
[x for x in theano.gof.graph.inputs([y])
if x in diff_inputs])
return [gmp.get(p, None) for p in diff_inputs] return [gmp.get(p, None) for p in diff_inputs]
dC_dinps_t = [None for inp in diff_inputs]
# 6. clean the outputs (i.e. remove update rules) disconnected_dC_dinps_t = [True for inp in diff_inputs]
end = (self.n_mit_mot_outs + dC_dXts = []
self.n_mit_sot + Xts = []
self.n_sit_sot + for idx, Xt in enumerate(diff_outputs):
self.n_nit_sot)
clean_outputs = self_outputs[:end] # We are looking for x[t-1] for a given x[t]
g_outs_no_shared = g_outs[:end] if idx >= self.n_mit_mot_outs:
Xt_placeholder = Xt.type()
# 7.1. empty lists to hold gradients Xts.append(Xt_placeholder)
# List of slices from outputs (used to compute the gradients) if Xt not in self.inner_nitsot_outs(self_outputs):
inner_g_outs = [] # What we do here is loop through dC_douts and collect all
g_out_slices = [] # those that are connected to the specific one and do an
# List of outputs of the gradient function # upcast on all of their dtypes to get the dtype for this
inner_gfn_outs = [] # specific output. Deciding if the gradient with this
# slices of the input # specific previous step is defined or not is done somewhere
prev_inner_gfn_outs = [] # else.
zeros_like_diff_ins = [] dtypes = []
pos = (self.n_seqs + states = (self.inner_mitmot(self_inputs) +
n_ins_mit_mot + self.inner_mitsot(self_inputs) +
n_ins_mit_sot + self.inner_sitsot(self_inputs))
self.n_sit_sot)
offset = len(args) - len(other_args) - pos for pos, inp in enumerate(states):
# 7.2. generate variables to represent previous steps of g_outs if inp in theano.gof.graph.inputs([Xt]):
for idx, diff_in in enumerate(diff_inputs): oidx = self.get_output_pos(pos)
prev_gfn_out = safe_new(diff_in) if not isinstance(dC_douts[oidx].type,
if hasattr(diff_in, 'name') and diff_in.name: DisconnectedType):
prev_gfn_out.name = 'g_prev_' + diff_in.name dtypes.append(dC_douts[oidx].dtype)
else: if dtypes:
prev_gfn_out.name = 'g_prev_' + str(idx) new_dtype = theano.scalar.upcast(*dtypes)
prev_inner_gfn_outs.append(prev_gfn_out) else:
if idx < pos: new_dtype = theano.config.floatX
zeros_like_diff_ins.append(tensor.zeros_like(diff_in)) dC_dXt = safe_new(Xt, dtype=new_dtype)
else:
zeros_like_diff_ins.append(
tensor.zeros_like(args[idx + offset]))
# 7.3. compute gradients of the inputs given one output
for dx, out in enumerate(clean_outputs):
if g_outs[dx] != None:
inner_g_out = safe_new(g_outs[dx][0])
else: else:
# We do not have a gradient on this output so we need a if isinstance(dC_douts[idx].type, DisconnectedType):
# placeholder, which for now has the same dtype as the continue
# output dC_dXt = safe_new(dC_douts[idx][0])
inner_g_out = safe_new(out) dC_dXts.append(dC_dXt)
### _dC_dinps_t = compute_gradient(Xt, dC_dXt)
#### I need to clip the gradient HERE !! for jdx in xrange(len(_dC_dinps_t)):
if dC_dinps_t[jdx] is None:
if g_outs_no_shared[dx]: dC_dinps_t[jdx] = _dC_dinps_t[jdx]
g_out_slices.append(g_outs_no_shared[dx][0]) elif _dC_dinps_t[jdx]:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
# mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]:
dC_dinps_t[dx] = tensor.zeros_like(diff_inputs[dx])
else: else:
g_out_slices.append(None) disconnected_dC_dinps_t[dx] = False
if getattr(out, 'name', None) is not None: for Xt, Xt_placeholder in zip(
inner_g_out.name = 'g_' + out.name diff_outputs[self.n_mit_mot_outs:],
Xts):
tmp = forced_replace(
dC_dinps_t[dx],
Xt,
Xt_placeholder)
dC_dinps_t[dx] = tmp
# construct dX_dtm1
dC_dXtm1s = []
for pos, x in enumerate(dC_dinps_t[self.n_seqs:]):
opos = self.get_output_pos(pos)
if opos >= 0:
dC_dXtm1s.append(dC_dXts[opos].type())
if x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype)
else: else:
inner_g_out.name = 'g_' + str(dx) dC_dXtm1s.append(x.type())
inner_g_outs.append(inner_g_out) for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
_g_out = inner_g_out dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
grad_outs = compute_gradient(out, _g_out) # Construct scan op
if not inner_gfn_outs: # Seqs
for idx, gfn_out in enumerate(grad_outs): outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
if idx >= self.n_seqs:
inner_gfn_outs.append(prev_inner_gfn_outs[idx])
else:
inner_gfn_outs.append(None)
# 7.4 Sum the gradients
# safety check, some of this inputs might still not be
# differentiable, for those we don't add them to the mix
# (assume their gradient is 0)
for i, (x, y) in enumerate(zip(grad_outs, inner_gfn_outs)):
if x and y:
inner_gfn_outs[i] = x + y
elif y:
inner_gfn_outs[i] = y
else:
inner_gfn_outs[i] = x
## 8. Mask the outputs that are not differentiable
# backwards pass
for i in xrange(len(inner_gfn_outs)):
if inner_gfn_outs[i] is None:
inner_gfn_outs[i] = tensor.zeros_like(diff_inputs[i])
## 9. Mask the g_outs that are Nones :
for i, out in enumerate(scan_outputs):
if g_outs[i] is None:
try:
# this try is for catching non ndarray inputs (random
# states) it is more of a safety check ( all random
# states should be after n_outs_not_shared ...
g_outs[i] = tensor.zeros_like(scan_outputs[i])
except Exception:
g_outs[i] = theano.tensor.constant(
numpy.array(0, theano.config.floatX))
## 10. Get your sequence in order for the scan:
n_seqs = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_nit_sot)
offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot)
inner_seqs = (seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
inner_g_outs[offset:offset + self.n_nit_sot])
scan_seqs = [x[::-1] for x in args[1:self.n_seqs + 1]]
offset = 0
for idx in xrange(self.n_mit_mot + self.n_mit_sot): for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx]) mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx]) maxtap = numpy.max(self.tap_array[idx])
seq = scan_outputs[offset + idx] seq = outs[idx]
for k in self.tap_array[idx]: for k in self.tap_array[idx]:
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
if maxtap < 0: if maxtap < 0:
dim_offset = abs(maxtap) dim_offset = abs(maxtap)
else: else:
...@@ -1397,126 +1470,187 @@ class Scan(PureOp): ...@@ -1397,126 +1470,187 @@ class Scan(PureOp):
-(maxtap - k + 1)][::-1] -(maxtap - k + 1)][::-1]
else: else:
nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1] nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1]
if getattr(seq, 'name', None) is not None: outer_inp_seqs.append(nw_seq)
nw_seq.name = seq.name + '[%d:]' % k outer_inp_seqs += [
scan_seqs.append(nw_seq) x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts):
offset += self.n_mit_sot if not isinstance(x.type, DisconnectedType):
for idx in xrange(self.n_sit_sot): outer_inp_seqs.append(x[::-1])
seq = scan_outputs[offset + idx][:-1]
scan_seqs.append(seq[::-1]) outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_sitsot_outs(outs)]
offset = (self.n_mit_mot_outs + outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
self.n_mit_sot +
self.n_sit_sot) inner_inp_seqs = self.inner_seqs(self_inputs)
scan_seqs += [x[::-1] for x in inner_inp_seqs += self.inner_mitmot(self_inputs)
g_outs[offset:offset + self.n_nit_sot]] inner_inp_seqs += self.inner_mitsot(self_inputs)
inner_inp_seqs += self.inner_sitsot(self_inputs)
scan_mit_mot = [] inner_inp_seqs += self.inner_nitsot_outs(dC_dXts)
inner_mit_mot = [] inner_inp_seqs += Xts
scan_mit_mot_outs = [] # mitmot
mit_mot_taps = [] outer_inp_mitmot = []
mit_mot_out_slices = [] outer_out_mitmot = []
inner_inp_mitmot = []
inner_out_mitmot = []
mitmot_inp_taps = []
mitmot_out_taps = []
type_outs = []
out_pos = 0 out_pos = 0
ins_pos = n_seqs
n_mit_mot_outs = 0
n_mit_mot_ins = 0
ins_pos = self.n_seqs ins_pos = self.n_seqs
n_mitmot_outs = 0
n_mitmot_inps = 0
for idx in xrange(self.n_mit_mot): for idx in xrange(self.n_mit_mot):
scan_mit_mot.append(g_outs[idx][::-1]) outer_inp_mitmot.append(dC_douts[idx][::-1])
mit_mot_taps.append([]) mitmot_inp_taps.append([])
mit_mot_out_slices.append([]) mitmot_out_taps.append([])
undefined = False
disconnected = True
for jdx in xrange(len(self.mit_mot_out_slices[idx])): for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_mit_mot.append(inner_g_outs[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
mit_mot_taps[idx].append(\ mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx])
-self.mit_mot_out_slices[idx][jdx]) n_mitmot_inps += 1
n_mit_mot_ins += 1
out_pos += 1 out_pos += 1
for jdx in xrange(len(self.tap_array[idx])): for jdx in xrange(len(self.tap_array[idx])):
inner_mit_mot.append(prev_inner_gfn_outs[ins_pos]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
scan_mit_mot_outs.append(\ inner_out_mitmot.append(dC_dinps_t[ins_pos])
inner_gfn_outs[ins_pos]) if not disconnected_dC_dinps_t[ins_pos]:
n_mit_mot_ins += 1 disconnected = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
n_mitmot_inps_ += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_outs += 1 n_mitmot_outs += 1
mit_mot_taps[idx].append(-self.tap_array[idx][jdx]) mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mit_mot_out_slices[idx].append(\ mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
-self.tap_array[idx][jdx]) if undefined:
type_outs.append('undefined')
elif disconnected:
type_outs.append('disconnected')
else:
type_outs.append('connected')
offset = self.n_mit_mot offset = self.n_mit_mot
for idx in xrange(self.n_mit_sot): for idx in xrange(self.n_mit_sot):
mit_mot_taps.append([]) mitmot_inp_taps.append([])
mit_mot_out_slices.append([]) mitmot_out_taps.append([])
scan_mit_mot.append(g_outs[idx + offset][::-1]) outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
idx_tap = idx + self.n_mit_mot idx_tap = idx + self.n_mit_mot
inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1
n_mitmot_inps += 1
undefined = False
disconnected = True
mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])): for jdx in xrange(len(self.tap_array[idx_tap])):
inner_mit_mot.append(prev_inner_gfn_outs[ins_pos]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
mit_mot_taps[idx + offset].append(\ inner_out_mitmot.append(dC_dinps_t[ins_pos])
mitmot_inp_taps[idx + offset].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
mit_mot_out_slices[idx].append(\ mitmot_out_taps[idx].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
scan_mit_mot_outs.append(inner_gfn_outs[ins_pos]) if not disconnected_dC_dinps_t[ins_pos]:
n_mit_mot_ins += 1 disconnected = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
n_mitmot_inps += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_outs += 1 n_mitmot_outs += 1
inner_mit_mot.append(inner_g_outs[out_pos]) if undefined:
out_pos += 1 type_outs.append('undefined')
n_mit_mot_ins += 1 elif disconnected:
mit_mot_taps[idx + offset].append(0) type_outs.append('disconnected')
else:
type_outs.append('connected')
offset += self.n_mit_sot offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot): for idx in xrange(self.n_sit_sot):
mit_mot_taps.append([0, 1]) mitmot_inp_taps.append([0, 1])
mit_mot_out_slices.append([1]) mitmot_out_taps.append([1])
scan_mit_mot.append(g_outs[idx + offset][::-1]) undefined = False
scan_mit_mot_outs.append(inner_gfn_outs[ins_pos]) if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
inner_mit_mot += [inner_g_outs[out_pos], outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
prev_inner_gfn_outs[ins_pos]] else:
n_mit_mot_outs += 1 outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=dC_dinps_t[ins_pos].dtype))
inner_out_mitmot.append(dC_dinps_t[ins_pos])
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
if undefined:
type_outs.append('undefined')
elif disconnected_dC_dinps_t[ins_pos]:
type_outs.append('disconnected')
else:
type_outs.append('connected')
inner_inp_mitmot += [dC_dXts[out_pos],
dC_dXtm1s[ins_pos - self.n_seqs]]
n_mitmot_outs += 1
out_pos += 1 out_pos += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_ins += 2 n_mitmot_inps += 2
n_nit_sot = self.n_seqs
scan_nit_sot_outs = inner_gfn_outs[:self.n_seqs]
if self.truncate_gradient != -1: if self.truncate_gradient != -1:
do_steps = tensor.minimum(args[0], self.truncate_gradient) grad_steps = tensor.minimum(grad_steps, self.truncate_gradient)
else:
do_steps = args[0] n_nit_sot = self.n_seqs
offset = (self.n_seqs + inner_out_nitsot = dC_dinps_t[:self.n_seqs]
n_ins_mit_sot + inner_out_sitsot = dC_dinps_t[ins_pos:]
n_ins_mit_mot + for _p, vl in enumerate(inner_out_sitsot):
self.n_sit_sot) undefined = False
# Instead of shared outs use sit_sot for _sh in self.inner_shared(self_inputs):
n_sitsot_outs = len(prev_inner_gfn_outs[offset:]) if _sh in gof.graph.inputs([vl]):
scan_sitsot_ins = prev_inner_gfn_outs[offset:] undefined = True
scan_sitsot_init = [] if undefined:
for x in zeros_like_diff_ins[offset:]: type_outs.append('undefined')
shapes = [x.shape[i] for i in xrange(x.ndim)] elif disconnected_dC_dinps_t[_p + ins_pos]:
empty = tensor.zeros([do_steps + 1] + shapes, type_outs.append('disconnected')
dtype=x.dtype) else:
scan_sitsot_init.append(empty) type_outs.append('connected')
scan_sitsot_outs = inner_gfn_outs[offset:]
tap_array = mit_mot_taps + [[-1] for k in for _p, vl in enumerate(inner_out_nitsot):
undefined = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]):
undefined = True
if undefined:
type_outs.append('undefined')
elif disconnected_dC_dinps_t[_p]:
type_outs.append('disconnected')
else:
type_outs.append('connected')
inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs:]
outer_inp_sitsot = [
tensor.zeros([grad_steps + 1] +
[x.shape[i] for i in xrange(x.ndim)],
dtype=y.dtype)
for y, x in zip(inner_inp_sitsot,
self.outer_non_seqs(inputs))]
n_sitsot_outs = len(outer_inp_sitsot)
new_tap_array = mitmot_inp_taps + [[-1] for k in
xrange(n_sitsot_outs)] xrange(n_sitsot_outs)]
info = {} info = {}
info['n_seqs'] = n_seqs info['n_seqs'] = len(outer_inp_seqs)
info['n_mit_sot'] = 0 info['n_mit_sot'] = 0
info['tap_array'] = tap_array info['tap_array'] = new_tap_array
info['gpu'] = False info['gpu'] = False
n_mit_mot = (self.n_mit_mot + info['n_mit_mot'] = len(outer_inp_mitmot)
self.n_mit_sot + info['n_mit_mot_outs'] = n_mitmot_outs
self.n_sit_sot) info['mit_mot_out_slices'] = mitmot_out_taps
info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices
info['truncate_gradient'] = self.truncate_gradient info['truncate_gradient'] = self.truncate_gradient
info['n_sit_sot'] = n_sitsot_outs info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = self.n_shared_outs info['n_shared_outs'] = 0
info['n_nit_sot'] = n_nit_sot info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while info['as_while'] = False
info['profile'] = self.profile info['profile'] = self.profile
info['destroy_map'] = {} info['destroy_map'] = {}
if self.name: if self.name:
...@@ -1524,70 +1658,96 @@ class Scan(PureOp): ...@@ -1524,70 +1658,96 @@ class Scan(PureOp):
else: else:
info['name'] = None info['name'] = None
info['mode'] = self.mode info['mode'] = self.mode
n_mit_sot = 0
n_sit_sot = 0
offset = (1 + outer_inputs = ([grad_steps] +
self.n_seqs + outer_inp_seqs +
self.n_mit_mot + outer_inp_mitmot +
self.n_mit_sot + outer_inp_sitsot +
self.n_sit_sot + [inputs[0] for x in xrange(n_nit_sot)] +
self.n_nit_sot + self.outer_shared(inputs) +
self.n_shared_outs) self.outer_non_seqs(inputs))
scan_inputs = ([do_steps] +
scan_seqs +
scan_mit_mot +
scan_sitsot_init +
old_scan_init +
[args[0] for x in xrange(n_nit_sot)] +
args[offset:])
offset = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_shared_outs)
inner_other_args = self_inputs[offset:] inner_other_args = self_inputs[offset:]
inner_gfn_ins = (inner_seqs + inner_gfn_ins = (inner_inp_seqs +
inner_mit_mot + inner_inp_mitmot +
scan_sitsot_ins + inner_inp_sitsot +
old_scan_shared_ins + self.inner_shared(self_inputs) +
inner_other_args) self.inner_non_seqs(self_inputs))
inner_gfn_outs = (scan_mit_mot_outs + inner_gfn_outs = (inner_out_mitmot +
scan_sitsot_outs + inner_out_sitsot +
scan_nit_sot_outs + inner_out_nitsot)
old_scan_shared_outs)
local_op = Scan(inner_gfn_ins, inner_gfn_outs, info) local_op = Scan(inner_gfn_ins, inner_gfn_outs, info)
outputs = local_op(*scan_inputs) outputs = local_op(*outer_inputs)
if type(outputs) not in (list, tuple): if type(outputs) not in (list, tuple):
outputs = [outputs] outputs = [outputs]
# Re-order the gradients correctly # Re-order the gradients correctly
gradients = [grad_undefined(self, 0, args[0], 'Number of steps')] gradients = [DisconnectedType()()]
offset = (self.n_mit_mot + offset = (self.n_mit_mot +
self.n_mit_sot + self.n_mit_sot +
self.n_sit_sot + self.n_sit_sot +
n_sitsot_outs) n_sitsot_outs)
gradients += [x[::-1] for x in outputs[offset:offset + self.n_seqs]] for p, (x, t) in enumerate(
zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])):
if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append(
grad_undefined(self,
p + 1,
inputs[p + 1],
'Depends on a shared variable'))
else:
gradients.append(x[::-1])
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
gradients += [x[::-1] for x in outputs[:end]] for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])):
if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append(
grad_undefined(self,
p + 1 + self.n_seqs,
inputs[p + 1 + self.n_seqs],
'Depends on a shared variable'))
else:
gradients.append(x[::-1])
start = len(gradients) start = len(gradients)
gradients += [ node = outs[0].owner
grad_undefined(self, x + start, args[x + start], for idx in xrange(self.n_shared_outs):
'Shared Variable with update') disconnected = True
for x in xrange(self.n_shared_outs)] connected_flags = self.connection_pattern(node)[idx+start]
for dC_dout, connected in zip(dC_douts, connected_flags):
if (not isinstance(dC_dout.type, DisconnectedType) and
connected):
disconnected = False
if disconnected:
gradients.append(DisconnectedType()())
else:
gradients.append(grad_undefined(
self, idx, inputs[idx],
'Shared Variable with update'))
start = len(gradients) start = len(gradients)
gradients += [ gradients += [DisconnectedType()()
grad_undefined(self, x + start, args[x + start],
'Dimension of memory buffer for output')
for x in xrange(self.n_nit_sot)] for x in xrange(self.n_nit_sot)]
begin = end begin = end
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
gradients += [x[-1] for x in outputs[begin:end]] for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])):
if t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
gradients.append(
grad_undefined(self,
p + begin + 1,
inputs[p + begin + 1],
'Depends on a shared variable'))
else:
gradients.append(x[-1])
return gradients return gradients
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
...@@ -23,7 +23,8 @@ from theano import gof ...@@ -23,7 +23,8 @@ from theano import gof
from theano.gof.python25 import maxsize from theano.gof.python25 import maxsize
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import deep_copy_op, optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op
import scan_op import scan_op
import scan_utils import scan_utils
...@@ -221,7 +222,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -221,7 +222,7 @@ class PushOutNonSeqScan(gof.Optimizer):
'to move some computation fron scan ' 'to move some computation fron scan '
'which is not allowed to move. Report ' 'which is not allowed to move. Report '
'this on theano-users list'), x) 'this on theano-users list'), x)
outside_ins = [x.type.filter_variable(y) for x,y in outside_ins = [x.type.filter_variable(y) for x, y in
zip(nd.inputs, outside_ins)] zip(nd.inputs, outside_ins)]
nw_outer_node = nd.op.make_node(*outside_ins) nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
...@@ -681,14 +682,18 @@ class ScanSaveMem(gof.Optimizer): ...@@ -681,14 +682,18 @@ class ScanSaveMem(gof.Optimizer):
if (nw_inputs[offset + idx].owner and if (nw_inputs[offset + idx].owner and
isinstance(nw_inputs[offset + idx].owner.op, isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor) and tensor.IncSubtensor) and
isinstance(nw_inputs[offset+idx].owner.op.idx_list[0], slice)): isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0],
slice)):
_nw_input = nw_inputs[offset + idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = tensor.as_tensor_variable(val) cval = tensor.as_tensor_variable(val)
initl = tensor.as_tensor_variable(init_l[i]) initl = tensor.as_tensor_variable(init_l[i])
tmp_idx = tensor.switch(cval < initl, tmp_idx = tensor.switch(cval < initl,
cval + initl, cval + initl,
cval - initl) cval - initl)
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp_idx) tmp = pre_greedy_local_optimizer(list_opt_slice,
tmp_idx)
tmp = pre_constant_merge([tmp])[0] tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand(_nw_input, tmp) nw_input = scan_utils.expand(_nw_input, tmp)
......
...@@ -33,7 +33,7 @@ from theano.tensor.basic import get_constant_value ...@@ -33,7 +33,7 @@ from theano.tensor.basic import get_constant_value
_logger = logging.getLogger('theano.scan_utils') _logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag=''): def safe_new(x, tag='', dtype=None):
""" """
Internal function that constructs a new variable from x with the same Internal function that constructs a new variable from x with the same
type, but with a different name (old name + tag). This function is used type, but with a different name (old name + tag). This function is used
...@@ -46,12 +46,18 @@ def safe_new(x, tag=''): ...@@ -46,12 +46,18 @@ def safe_new(x, tag=''):
else: else:
nw_name = None nw_name = None
if isinstance(x, theano.Constant): if isinstance(x, theano.Constant):
return x.clone() if dtype and x.dtype != dtype:
return x.clone().astype(dtype)
else:
return x.clone()
# Note, as_tensor_variable will convert the Scalar into a # Note, as_tensor_variable will convert the Scalar into a
# TensorScalar that will require a ScalarFromTensor op, # TensorScalar that will require a ScalarFromTensor op,
# making the pushout optimization fail # making the pushout optimization fail
elif isinstance(x, scalar.ScalarVariable): elif isinstance(x, scalar.ScalarVariable):
nw_x = x.type() if dtype:
new_x = scalar.Scalar(dtype=dtype)()
else:
nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
return nw_x return nw_x
else: else:
...@@ -63,6 +69,8 @@ def safe_new(x, tag=''): ...@@ -63,6 +69,8 @@ def safe_new(x, tag=''):
# ndarrays # ndarrays
pass pass
nw_x = x.type() nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype)
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used. # Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions # The test value is deep-copied to ensure there can be no interactions
...@@ -930,3 +938,34 @@ class scan_args(object): ...@@ -930,3 +938,34 @@ class scan_args(object):
'mit_sot_in_slices')): 'mit_sot_in_slices')):
getattr(res, attr).extend(getattr(other, attr)) getattr(res, attr).extend(getattr(other, attr))
return res return res
def forced_replace(out, x, y):
    """
    Check all internal values of the graph that computes the variable
    ``out`` for occurrences of values identical with ``x``. Every such
    occurrence is replaced with the variable ``y``.

    :param out: Theano Variable (root of the graph to rewrite); may be
        None, in which case None is returned.
    :param x: Theano Variable to search for. Matches are structural
        (via ``equal_computations``), not by object identity.
    :param y: Theano Variable substituted for each match.

    For example:
        out := sigmoid(wu)*(1-sigmoid(wu))
        x := sigmoid(wu)
        forced_replace(out, x, y) := y*(1-y)
    """
    if out is None:
        return None

    # Iterative depth-first walk over the graph. The ``visited`` set
    # (keyed by id, since we want node identity rather than structural
    # equality here) stops us from re-walking shared subgraphs, which
    # made the naive recursion exponential on graphs with heavy sharing.
    # Like the recursive version, we do not descend into a matched node.
    visited = set()
    to_replace = []
    stack = [out]
    while stack:
        node = stack.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        if equal_computations([node], [x]):
            to_replace.append(node)
        elif node.owner:
            stack.extend(node.owner.inputs)
    return clone(out, replace=dict((v, y) for v in to_replace))
...@@ -513,7 +513,7 @@ class T_Scan(unittest.TestCase): ...@@ -513,7 +513,7 @@ class T_Scan(unittest.TestCase):
def f_rnn(u_t, x_tm1, W_in, W): def f_rnn(u_t, x_tm1, W_in, W):
return (u_t * W_in + x_tm1 * W, return (u_t * W_in + x_tm1 * W,
tensor.cast(u_t+x_tm1, 'int64')) tensor.cast(u_t + x_tm1, 'int64'))
u = theano.tensor.fvector('u') u = theano.tensor.fvector('u')
x0 = theano.tensor.fscalar('x0') x0 = theano.tensor.fscalar('x0')
...@@ -561,7 +561,6 @@ class T_Scan(unittest.TestCase): ...@@ -561,7 +561,6 @@ class T_Scan(unittest.TestCase):
scan_node = scan_node[0] scan_node = scan_node[0]
assert scan_node.op.gpu assert scan_node.op.gpu
# simple rnn, one input, one state, weights for each; input/state # simple rnn, one input, one state, weights for each; input/state
# are vectors, weights are scalars; using shared variables # are vectors, weights are scalars; using shared variables
def test_one_sequence_one_output_weights_shared(self): def test_one_sequence_one_output_weights_shared(self):
...@@ -1124,6 +1123,29 @@ class T_Scan(unittest.TestCase): ...@@ -1124,6 +1123,29 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(W1.get_value(), numpy_W1) assert numpy.allclose(W1.get_value(), numpy_W1)
assert numpy.allclose(W2.get_value(), numpy_W2) assert numpy.allclose(W2.get_value(), numpy_W2)
def test_grad_dtype_change(self):
    # Regression test: taking the gradient through a scan whose states
    # mix dtypes (an int32 condition alongside float32 states) used to
    # fail; we only check that compilation and execution run cleanly.
    x = tensor.fscalar('x')
    y = tensor.fscalar('y')
    c = tensor.iscalar('c')

    def step(cond, x_t, y_t):
        # The condition state is int32 while the other two are float32.
        cond_next = tensor.cast(tensor.switch(cond, x_t, y_t), 'int32')
        x_next = tensor.switch(cond, tensor.nnet.sigmoid(y_t * x_t), x_t)
        y_next = tensor.switch(cond, y_t, tensor.nnet.sigmoid(x_t))
        return cond_next, x_next, y_next

    values, _ = theano.scan(step,
                            outputs_info=[c, x, y],
                            n_steps=10,
                            truncate_gradient=-1,
                            go_backwards=False)
    gX, gY = tensor.grad(values[1].sum(), [x, y])
    f = theano.function([c, x, y], [gX, gY],
                        allow_input_downcast=True)
    # Check for runtime errors
    f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))
def test_simple_shared_mrg_random(self): def test_simple_shared_mrg_random(self):
theano_rng = theano.sandbox.rng_mrg.MRG_RandomStreams(utt.fetch_seed()) theano_rng = theano.sandbox.rng_mrg.MRG_RandomStreams(utt.fetch_seed())
...@@ -1470,8 +1492,11 @@ class T_Scan(unittest.TestCase): ...@@ -1470,8 +1492,11 @@ class T_Scan(unittest.TestCase):
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False) go_backwards=False)
vparams = [v_u1, v_u2, v_x0, v_y0, vW_in1] vparams = [v_u1, v_u2, v_x0, v_y0, vW_in1]
# y0 is actually not used in the computation of the cost
params = [u1, u2, x0, y0, W_in1] params = [u1, u2, x0, y0, W_in1]
gparams = theano.tensor.grad(cost, params) gparams = theano.grad(cost, params,
disconnected_inputs='ignore')
grad_fn = theano.function([u1, u2, x0, y0, W_in1], grad_fn = theano.function([u1, u2, x0, y0, W_in1],
gparams, gparams,
updates=updates, updates=updates,
...@@ -1711,8 +1736,8 @@ class T_Scan(unittest.TestCase): ...@@ -1711,8 +1736,8 @@ class T_Scan(unittest.TestCase):
def f_rnn_cmpl(u_t, x_tm1, W_in): def f_rnn_cmpl(u_t, x_tm1, W_in):
trng1 = theano.tensor.shared_randomstreams.RandomStreams(123) trng1 = theano.tensor.shared_randomstreams.RandomStreams(123)
x_t = theano.dot(u_t, W_in) + x_tm1 + \ rnd_nb = trng1.uniform(low=-.1, high=.1)
trng1.uniform(low=-.1, high=.1) x_t = theano.dot(u_t, W_in) + x_tm1 + rnd_nb
x_t = theano.tensor.cast(x_t, dtype=theano.config.floatX) x_t = theano.tensor.cast(x_t, dtype=theano.config.floatX)
return x_t return x_t
...@@ -1874,8 +1899,8 @@ class T_Scan(unittest.TestCase): ...@@ -1874,8 +1899,8 @@ class T_Scan(unittest.TestCase):
def test_scan_extra_inputs_hessian(self): def test_scan_extra_inputs_hessian(self):
x = theano.tensor.vector('x') x = theano.tensor.vector('x')
A = theano.tensor.matrix('A') A = theano.tensor.matrix('A')
fc1 = theano.shared(0.5, name = 'fc1') fc1 = theano.shared(0.5, name='fc1')
fc2 = theano.shared(0.9, name = 'fc2') fc2 = theano.shared(0.9, name='fc2')
y = fc1 * theano.dot(x * x, theano.dot(A, x)) y = fc1 * theano.dot(x * x, theano.dot(A, x))
y.name = 'y' y.name = 'y'
gy = theano.tensor.grad(y, x) gy = theano.tensor.grad(y, x)
...@@ -2316,12 +2341,13 @@ class T_Scan(unittest.TestCase): ...@@ -2316,12 +2341,13 @@ class T_Scan(unittest.TestCase):
allow_input_downcast=True, mode=mode_with_opt) allow_input_downcast=True, mode=mode_with_opt)
self.assertTrue(numpy.allclose(f([1, 2, 3]), 2. / 3)) self.assertTrue(numpy.allclose(f([1, 2, 3]), 2. / 3))
#theano.printing.debugprint(f, print_type=True)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
# this new assert is here to test if scan_merging works .. # this new assert is here to test if scan_merging works ..
nb_scan = len([n for n in topo nb_scan = len([n for n in topo
if isinstance(n.op, theano.scan_module.scan_op.Scan)]) if isinstance(n.op, theano.scan_module.scan_op.Scan)])
self.assertTrue(nb_scan == 1) # For this to work we need an optimization that it will be pushed in
# a new pull request
self.assertTrue(nb_scan == 2)
nb_shape_i = len([n for n in topo nb_shape_i = len([n for n in topo
if isinstance(n.op, theano.tensor.opt.Shape_i)]) if isinstance(n.op, theano.tensor.opt.Shape_i)])
if theano.config.mode != 'FAST_COMPILE': if theano.config.mode != 'FAST_COMPILE':
...@@ -2511,10 +2537,10 @@ class T_Scan(unittest.TestCase): ...@@ -2511,10 +2537,10 @@ class T_Scan(unittest.TestCase):
def rnn_fn(_u, _y, _W): def rnn_fn(_u, _y, _W):
srng = theano.tensor.shared_randomstreams.RandomStreams(seed) srng = theano.tensor.shared_randomstreams.RandomStreams(seed)
sl_o = theano.tensor.tanh(theano.tensor.dot(_W, (_u + _y + \ tmp_val = _u + _y + srng.uniform(size=v_h0.shape) *\
srng.uniform(size=v_h0.shape) * numpy.asarray(1e-6, dtype=floatX)
numpy.asarray(1e-6, dtype=floatX)))) sl_o = theano.tensor.tanh(theano.tensor.dot(_W, tmp_val))
return sl_o return sl_o, tmp_val
u = theano.tensor.matrix('U') u = theano.tensor.matrix('U')
h0 = theano.tensor.vector('h0') h0 = theano.tensor.vector('h0')
...@@ -2527,9 +2553,9 @@ class T_Scan(unittest.TestCase): ...@@ -2527,9 +2553,9 @@ class T_Scan(unittest.TestCase):
_W = theano.tensor.specify_shape(W, v_W.shape) _W = theano.tensor.specify_shape(W, v_W.shape)
_W.name = '_W' _W.name = '_W'
o, _ = theano.scan(rnn_fn, [o, _], _ = theano.scan(rnn_fn,
sequences=_u, sequences=_u,
outputs_info=_h0, outputs_info=[_h0, None],
non_sequences=_W, non_sequences=_W,
name='rnn_fn') name='rnn_fn')
o = o[-1] o = o[-1]
...@@ -3110,6 +3136,7 @@ class T_Scan(unittest.TestCase): ...@@ -3110,6 +3136,7 @@ class T_Scan(unittest.TestCase):
loss, loss,
no_default_updates=True, no_default_updates=True,
allow_input_downcast=True) allow_input_downcast=True)
gw, gx = tensor.grad(loss, [w, xinit]) gw, gx = tensor.grad(loss, [w, xinit])
grad_fn = theano.function([xinit, w], [gx, gw], grad_fn = theano.function([xinit, w], [gx, gw],
allow_input_downcast=True) allow_input_downcast=True)
...@@ -3135,6 +3162,20 @@ class T_Scan(unittest.TestCase): ...@@ -3135,6 +3162,20 @@ class T_Scan(unittest.TestCase):
raise Exception(theano.tensor.verify_grad.E_grad, raise Exception(theano.tensor.verify_grad.E_grad,
(max_err, 1e-2, max_err_pos)) (max_err, 1e-2, max_err_pos))
def test_grad_numeric_shared(self):
    # The gradient of a cost w.r.t. a shared variable that is only
    # updated inside scan must be computable; for a pure +1 increment
    # per step the gradient of the final value w.r.t. the start is 1.
    shared_var = theano.shared(numpy.float32(1.))

    def step():
        # No outputs — only an update rule for the shared variable.
        return [], {shared_var: shared_var + numpy.float32(1.)}

    _, updates = theano.scan(step,
                             n_steps=10,
                             truncate_gradient=-1,
                             go_backwards=False)
    cost = updates.values()[0]
    g_sh = tensor.grad(cost, shared_var)
    fgrad = theano.function([], g_sh)
    assert fgrad() == 1
def test_rop_mitmot(self): def test_rop_mitmot(self):
# this test is a copy paste from the script given by Justin Bayer to # this test is a copy paste from the script given by Justin Bayer to
# reproduce this bug # reproduce this bug
...@@ -3188,17 +3229,17 @@ class T_Scan(unittest.TestCase): ...@@ -3188,17 +3229,17 @@ class T_Scan(unittest.TestCase):
Hp = tensor.Rop(d_cost_wrt_pars, pars, p) Hp = tensor.Rop(d_cost_wrt_pars, pars, p)
def test_seq_tap_bug_jeremiah(self): def test_seq_tap_bug_jeremiah(self):
inp = numpy.arange(10).reshape(-1,1).astype(theano.config.floatX) inp = numpy.arange(10).reshape(-1, 1).astype(theano.config.floatX)
exp_out = numpy.zeros((10,1)).astype(theano.config.floatX) exp_out = numpy.zeros((10, 1)).astype(theano.config.floatX)
exp_out[4:] = inp[:-4] exp_out[4:] = inp[:-4]
def onestep(x, x_tm4): def onestep(x, x_tm4):
return x, x_tm4 return x, x_tm4
seq = tensor.matrix() seq = tensor.matrix()
initial_value = theano.shared(numpy.zeros((4,1), initial_value = theano.shared(numpy.zeros((4, 1),
dtype=theano.config.floatX)) dtype=theano.config.floatX))
outputs_info = [{'initial' : initial_value, 'taps' : [-4]}, None] outputs_info = [{'initial': initial_value, 'taps': [-4]}, None]
results, updates = theano.scan(fn=onestep, results, updates = theano.scan(fn=onestep,
sequences=seq, sequences=seq,
outputs_info=outputs_info) outputs_info=outputs_info)
...@@ -3208,27 +3249,49 @@ class T_Scan(unittest.TestCase): ...@@ -3208,27 +3249,49 @@ class T_Scan(unittest.TestCase):
def test_borrow_bug_jeremiah(self): def test_borrow_bug_jeremiah(self):
# This test fails if scan uses wrongly the borrow flag # This test fails if scan uses wrongly the borrow flag
inp = numpy.arange(10).reshape(-1,1).astype(theano.config.floatX) inp = numpy.arange(10).reshape(-1, 1).astype(theano.config.floatX)
exp_out = numpy.zeros((10,1)).astype(theano.config.floatX) exp_out = numpy.zeros((10, 1)).astype(theano.config.floatX)
exp_out[4:] = inp[:-4] exp_out[4:] = inp[:-4]
def onestep(x, x_tm4): def onestep(x, x_tm4):
return x, x_tm4 return x, x_tm4
seq = tensor.matrix() seq = tensor.matrix()
initial_value = theano.shared(numpy.zeros((4,1), initial_value = theano.shared(numpy.zeros((4, 1),
dtype=theano.config.floatX)) dtype=theano.config.floatX))
outputs_info = [{'initial' : initial_value, 'taps' : [-4]}, None] outputs_info = [{'initial': initial_value, 'taps': [-4]}, None]
results, _ = theano.scan(fn=onestep, results, _ = theano.scan(fn=onestep,
sequences=seq, sequences=seq,
outputs_info=outputs_info) outputs_info=outputs_info)
sharedvar = theano.shared(numpy.zeros((1,1), sharedvar = theano.shared(numpy.zeros((1, 1),
dtype=theano.config.floatX)) dtype=theano.config.floatX))
updates = {sharedvar : results[0][-1:]} updates = {sharedvar: results[0][-1:]}
f = theano.function([seq], results[1], updates=updates) f = theano.function([seq], results[1], updates=updates)
assert numpy.all(exp_out == f(inp)) assert numpy.all(exp_out == f(inp))
def test_grad_connectivity_matrix(self):
    # Check that scan's connection pattern lets grad distinguish
    # connected initial states from disconnected ones.
    def recurrence(x_tm1, y_tm1, z_tm1):
        x_tm1.name = 'x'
        y_tm1.name = 'y'
        z_tm1.name = 'z'
        # z_tm1 is never read, so z0 only seeds a state that the
        # recurrence immediately overwrites from x.
        return x_tm1 ** 2, x_tm1 + y_tm1, x_tm1 + 1

    x0 = tensor.vector('X')
    y0 = tensor.vector('y0')
    z0 = tensor.vector('Z')
    [x, y, z], _ = theano.scan(recurrence,
                               outputs_info=[x0, y0, z0],
                               n_steps=10)

    cost = (x + y + z).sum()
    gx0 = tensor.grad(cost, x0)  # defined
    gy0 = tensor.grad(cost, y0)  # defined
    # z0 never influences the cost: grad must raise.
    self.assertRaises(ValueError, tensor.grad, cost, z0)

    cost = x.sum()
    # x alone does not depend on y0 either.
    self.assertRaises(ValueError, tensor.grad, cost, y0)
def test_speed(): def test_speed():
# #
# This function prints out the speed of very simple recurrent # This function prints out the speed of very simple recurrent
...@@ -3576,9 +3639,7 @@ if __name__ == '__main__': ...@@ -3576,9 +3639,7 @@ if __name__ == '__main__':
def test_compute_test_value(): def test_compute_test_value():
""" # Verify that test values can be used with scan.
Verify that test values can be used with scan.
"""
backup = theano.config.compute_test_value backup = theano.config.compute_test_value
theano.config.compute_test_value = 'raise' theano.config.compute_test_value = 'raise'
try: try:
...@@ -3590,7 +3651,7 @@ def test_compute_test_value(): ...@@ -3590,7 +3651,7 @@ def test_compute_test_value():
fn=lambda u, v: u + v, fn=lambda u, v: u + v,
sequences=[x, y]) sequences=[x, y])
assert not _ assert not _
z.name='z' z.name = 'z'
# The gradient computation used to crash before 6af465e. # The gradient computation used to crash before 6af465e.
g = tensor.grad(z.sum(), x) g = tensor.grad(z.sum(), x)
#f = theano.function([x], g) #f = theano.function([x], g)
......
...@@ -1076,6 +1076,7 @@ class TensorType(Type): ...@@ -1076,6 +1076,7 @@ class TensorType(Type):
""" """
return numpy.zeros(shape, dtype=self.dtype) return numpy.zeros(shape, dtype=self.dtype)
theano.compile.ops.expandable_types += (TensorType,)
# Register TensorType C code for ViewOp. # Register TensorType C code for ViewOp.
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
......
...@@ -390,8 +390,12 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s); ...@@ -390,8 +390,12 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
# Do not make the DimShuffle inplace as an optimization at the # Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace. # canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph. # The inplace will be reintroduced automatically later in the graph.
return [DimShuffle(gz.type.broadcastable, grad_order)( if 'int' in inp[0].dtype:
Elemwise(scalar.identity)(gz))] return [theano.tensor.zeros_like(inp[0],
dtype=theano.config.floatX)]
else:
return [DimShuffle(gz.type.broadcastable, grad_order)(
Elemwise(scalar.identity)(gz))]
class DimShufflePrinter: class DimShufflePrinter:
......
...@@ -256,7 +256,9 @@ class RandomFunction(gof.Op): ...@@ -256,7 +256,9 @@ class RandomFunction(gof.Op):
out[0] = rval out[0] = rval
def grad(self, inputs, outputs): def grad(self, inputs, outputs):
return [None for i in inputs] return [theano.gradient.grad_undefined(self, k, inp,
'No gradient defined through raw random numbers op')
for k, inp in enumerate(inputs)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None for i in eval_points] return [None for i in eval_points]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论