提交 595ec4b2 authored 作者: lamblin's avatar lamblin

Merge pull request #1009 from pascanur/scan_grad_dtype_issue

Scan grad dtype issue
...@@ -179,3 +179,7 @@ class DeepCopyOp(gof.Op): ...@@ -179,3 +179,7 @@ class DeepCopyOp(gof.Op):
deep_copy_op = DeepCopyOp() deep_copy_op = DeepCopyOp()
# List of Theano Types that one can add an extra dimension and for which
# Scan can deal with.
expandable_types = ()
...@@ -411,6 +411,7 @@ class CudaNdarrayType(Type): ...@@ -411,6 +411,7 @@ class CudaNdarrayType(Type):
def c_compile_args(self): def c_compile_args(self):
return [] return []
theano.compile.ops.expandable_types += (CudaNdarrayType,)
# Register C code for ViewOp on CudaNdarrayType # Register C code for ViewOp on CudaNdarrayType
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
......
...@@ -53,6 +53,7 @@ from theano.tensor import opt ...@@ -53,6 +53,7 @@ from theano.tensor import opt
from theano import tensor from theano import tensor
from theano import config from theano import config
from theano.updates import Updates from theano.updates import Updates
from theano.compile import ops
import scan_op import scan_op
...@@ -843,17 +844,38 @@ def scan(fn, ...@@ -843,17 +844,38 @@ def scan(fn,
shared_scan_inputs = [] shared_scan_inputs = []
shared_inner_inputs = [] shared_inner_inputs = []
shared_inner_outputs = [] shared_inner_outputs = []
sit_sot_shared = []
for input in dummy_f.maker.expanded_inputs: for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update: if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
if getattr(input.variable, 'name', None) is not None: if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy' new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append(new_var) if isinstance(new_var.type, ops.expandable_types):
shared_scan_inputs.append(input.variable) sit_sot_inner_inputs.append(new_var)
shared_inner_outputs.append(input.update) sit_sot_scan_inputs.append(
givens[input.variable] = new_var scan_utils.expand(
n_shared_outs += 1 tensor.unbroadcast(
tensor.shape_padleft(input.variable), 0),
actual_n_steps))
sit_sot_inner_outputs.append(input.update)
# Note that `pos` here is not used as a negative index. The sign of
# `pos` is used as a flag to indicate whether this output should be
# part of the update rules or part of the standard outputs of scan.
# If `pos` is positive then it corresponds to the standard outputs
# of scan and it refers to the output of index `pos`. If `pos` is
# negative then it corresponds to the update rules of scan and it
# refers to the update rule of index -1 - `pos`.
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
givens[input.variable] = new_var
else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
n_shared_outs += 1
n_sit_sot = len(sit_sot_inner_inputs)
## Step 5.4 Outputs with no taps used in the input ## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0 n_nit_sot = 0
nit_sot_inner_outputs = [] nit_sot_inner_outputs = []
...@@ -1041,10 +1063,20 @@ def scan(fn, ...@@ -1041,10 +1063,20 @@ def scan(fn,
nit_sot_rightOrder) nit_sot_rightOrder)
scan_out_list = [None] * len(rightOrder) scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder): for idx, pos in enumerate(rightOrder):
scan_out_list[pos] = _scan_out_list[idx] if pos >= 0:
scan_out_list[pos] = _scan_out_list[idx]
else:
# Note that `pos` here is not used as a negative index. The sign of
# `pos` is used as a flag to indicate whether this output should be
# part of the update rules or part of the standard outputs of scan.
# If `pos` is positive then it corresponds to the standard outputs
# of scan and it refers to the output of index `pos`. If `pos` is
# negative then it corresponds to the update rules of scan and it
# refers to the update rule of index -1 - `pos`.
update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
scan_out_list = [x for x in scan_out_list if x is not None]
if len(scan_out_list) == 1: if len(scan_out_list) == 1:
scan_out_list = scan_out_list[0] scan_out_list = scan_out_list[0]
elif len(scan_out_list) == 0: elif len(scan_out_list) == 0:
scan_out_list = None scan_out_list = None
return (scan_out_list, update_map) return (scan_out_list, update_map)
...@@ -23,7 +23,8 @@ from theano import gof ...@@ -23,7 +23,8 @@ from theano import gof
from theano.gof.python25 import maxsize from theano.gof.python25 import maxsize
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import deep_copy_op, optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op
import scan_op import scan_op
import scan_utils import scan_utils
...@@ -221,7 +222,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -221,7 +222,7 @@ class PushOutNonSeqScan(gof.Optimizer):
'to move some computation fron scan ' 'to move some computation fron scan '
'which is not allowed to move. Report ' 'which is not allowed to move. Report '
'this on theano-users list'), x) 'this on theano-users list'), x)
outside_ins = [x.type.filter_variable(y) for x,y in outside_ins = [x.type.filter_variable(y) for x, y in
zip(nd.inputs, outside_ins)] zip(nd.inputs, outside_ins)]
nw_outer_node = nd.op.make_node(*outside_ins) nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
...@@ -681,14 +682,18 @@ class ScanSaveMem(gof.Optimizer): ...@@ -681,14 +682,18 @@ class ScanSaveMem(gof.Optimizer):
if (nw_inputs[offset + idx].owner and if (nw_inputs[offset + idx].owner and
isinstance(nw_inputs[offset + idx].owner.op, isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor) and tensor.IncSubtensor) and
isinstance(nw_inputs[offset+idx].owner.op.idx_list[0], slice)): isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0],
slice)):
_nw_input = nw_inputs[offset + idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = tensor.as_tensor_variable(val) cval = tensor.as_tensor_variable(val)
initl = tensor.as_tensor_variable(init_l[i]) initl = tensor.as_tensor_variable(init_l[i])
tmp_idx = tensor.switch(cval < initl, tmp_idx = tensor.switch(cval < initl,
cval + initl, cval + initl,
cval - initl) cval - initl)
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp_idx) tmp = pre_greedy_local_optimizer(list_opt_slice,
tmp_idx)
tmp = pre_constant_merge([tmp])[0] tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand(_nw_input, tmp) nw_input = scan_utils.expand(_nw_input, tmp)
......
...@@ -33,7 +33,7 @@ from theano.tensor.basic import get_constant_value ...@@ -33,7 +33,7 @@ from theano.tensor.basic import get_constant_value
_logger = logging.getLogger('theano.scan_utils') _logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag=''): def safe_new(x, tag='', dtype=None):
""" """
Internal function that constructs a new variable from x with the same Internal function that constructs a new variable from x with the same
type, but with a different name (old name + tag). This function is used type, but with a different name (old name + tag). This function is used
...@@ -46,12 +46,18 @@ def safe_new(x, tag=''): ...@@ -46,12 +46,18 @@ def safe_new(x, tag=''):
else: else:
nw_name = None nw_name = None
if isinstance(x, theano.Constant): if isinstance(x, theano.Constant):
return x.clone() if dtype and x.dtype != dtype:
return x.clone().astype(dtype)
else:
return x.clone()
# Note, as_tensor_variable will convert the Scalar into a # Note, as_tensor_variable will convert the Scalar into a
# TensorScalar that will require a ScalarFromTensor op, # TensorScalar that will require a ScalarFromTensor op,
# making the pushout optimization fail # making the pushout optimization fail
elif isinstance(x, scalar.ScalarVariable): elif isinstance(x, scalar.ScalarVariable):
nw_x = x.type() if dtype:
new_x = scalar.Scalar(dtype=dtype)()
else:
nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
return nw_x return nw_x
else: else:
...@@ -63,6 +69,8 @@ def safe_new(x, tag=''): ...@@ -63,6 +69,8 @@ def safe_new(x, tag=''):
# ndarrays # ndarrays
pass pass
nw_x = x.type() nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype)
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used. # Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions # The test value is deep-copied to ensure there can be no interactions
...@@ -930,3 +938,34 @@ class scan_args(object): ...@@ -930,3 +938,34 @@ class scan_args(object):
'mit_sot_in_slices')): 'mit_sot_in_slices')):
getattr(res, attr).extend(getattr(other, attr)) getattr(res, attr).extend(getattr(other, attr))
return res return res
def forced_replace(out, x, y):
    """Replace every occurrence of ``x`` inside the graph of ``out`` by ``y``.

    :param out: Theano Variable (root of the graph to rewrite); may be None
    :param x: Theano Variable to search for
    :param y: Theano Variable to substitute in place of each match

    :return: a clone of ``out`` with every subgraph computationally
        identical to ``x`` replaced by ``y``, or None if ``out`` is None.

    This function checks all internal values of the graph that computes the
    variable ``out`` for occurrences of values identical with ``x``. If such
    occurrences are encountered then they are replaced with variable ``y``.
    For example:
        out := sigmoid(wu)*(1-sigmoid(wu))
        x := sigmoid(wu)
        forced_replace(out, x, y) := y*(1-y)
    """
    # Nothing to rewrite if there is no output graph.
    if out is None:
        return None

    def traverse(graph, x):
        # Collect every subgraph of ``graph`` that computes the same thing
        # as ``x`` (structural equality via equal_computations, not object
        # identity).
        if equal_computations([graph], [x]):
            return [graph]
        elif not graph.owner:
            # Leaf variable (input/constant) that did not match: no owner
            # means nothing further to descend into.
            return []
        else:
            # Recurse into every input of the node that produced ``graph``.
            # NOTE(review): shared subgraphs are re-traversed each time they
            # are reached; duplicates in the result are harmless because
            # they collapse in the replace dict below.
            rval = []
            for inp in graph.owner.inputs:
                rval += traverse(inp, x)
            return rval

    to_replace = traverse(out, x)
    # Clone ``out`` mapping each matched subgraph to ``y``.
    return clone(out, replace=dict((v, y) for v in to_replace))
...@@ -1076,6 +1076,7 @@ class TensorType(Type): ...@@ -1076,6 +1076,7 @@ class TensorType(Type):
""" """
return numpy.zeros(shape, dtype=self.dtype) return numpy.zeros(shape, dtype=self.dtype)
theano.compile.ops.expandable_types += (TensorType,)
# Register TensorType C code for ViewOp. # Register TensorType C code for ViewOp.
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
......
...@@ -390,8 +390,12 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s); ...@@ -390,8 +390,12 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
# Do not make the DimShuffle inplace as an optimization at the # Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace. # canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph. # The inplace will be reintroduced automatically later in the graph.
return [DimShuffle(gz.type.broadcastable, grad_order)( if 'int' in inp[0].dtype:
Elemwise(scalar.identity)(gz))] return [theano.tensor.zeros_like(inp[0],
dtype=theano.config.floatX)]
else:
return [DimShuffle(gz.type.broadcastable, grad_order)(
Elemwise(scalar.identity)(gz))]
class DimShufflePrinter: class DimShufflePrinter:
......
...@@ -256,7 +256,9 @@ class RandomFunction(gof.Op): ...@@ -256,7 +256,9 @@ class RandomFunction(gof.Op):
out[0] = rval out[0] = rval
def grad(self, inputs, outputs): def grad(self, inputs, outputs):
return [None for i in inputs] return [theano.gradient.grad_undefined(self, k, inp,
'No gradient defined through raw random numbers op')
for k, inp in enumerate(inputs)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None for i in eval_points] return [None for i in eval_points]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论