提交 219428ba authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Replace theano.tensor alias T with tt in theano.scan_module

上级 c90ef03d
...@@ -46,11 +46,13 @@ import logging ...@@ -46,11 +46,13 @@ import logging
import numpy as np import numpy as np
import theano.tensor as tt
from collections import OrderedDict from collections import OrderedDict
from six import integer_types from six import integer_types
from theano import compile, gof, tensor, config from theano import compile, gof, config
from theano.compile import SharedVariable, function, ops from theano.compile import SharedVariable, function, ops
from theano.tensor import opt from theano.tensor import opt
from theano.updates import OrderedUpdates from theano.updates import OrderedUpdates
...@@ -59,7 +61,6 @@ from theano.gof.utils import TestValueError ...@@ -59,7 +61,6 @@ from theano.gof.utils import TestValueError
from theano.scan_module import scan_op, scan_utils from theano.scan_module import scan_op, scan_utils
from theano.scan_module.scan_utils import safe_new, traverse from theano.scan_module.scan_utils import safe_new, traverse
# Logging function for sending warning or info
_logger = logging.getLogger("theano.scan_module.scan") _logger = logging.getLogger("theano.scan_module.scan")
...@@ -142,11 +143,11 @@ def scan( ...@@ -142,11 +143,11 @@ def scan(
.. code-block:: python .. code-block:: python
import theano.tensor as TT import theano.tensor as tt
W = TT.matrix() W = tt.matrix()
W_2 = W**2 W_2 = W**2
def f(x): def f(x):
return TT.dot(x,W_2) return tt.dot(x,W_2)
The function is expected to return two things. One is a list of The function is expected to return two things. One is a list of
outputs ordered in the same order as ``outputs_info``, with the outputs ordered in the same order as ``outputs_info``, with the
...@@ -374,7 +375,7 @@ def scan( ...@@ -374,7 +375,7 @@ def scan(
non_seqs = [] non_seqs = []
for elem in wrap_into_list(non_sequences): for elem in wrap_into_list(non_sequences):
if not isinstance(elem, gof.Variable): if not isinstance(elem, gof.Variable):
non_seqs.append(tensor.as_tensor_variable(elem)) non_seqs.append(tt.as_tensor_variable(elem))
else: else:
non_seqs.append(elem) non_seqs.append(elem)
...@@ -389,11 +390,11 @@ def scan( ...@@ -389,11 +390,11 @@ def scan(
else: else:
try: try:
n_fixed_steps = opt.get_scalar_constant_value(n_steps) n_fixed_steps = opt.get_scalar_constant_value(n_steps)
except tensor.basic.NotScalarConstantError: except tt.NotScalarConstantError:
n_fixed_steps = None n_fixed_steps = None
# Check n_steps is an int # Check n_steps is an int
if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in tensor.integer_dtypes: if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in tt.integer_dtypes:
raise ValueError( raise ValueError(
" n_steps must be an int. dtype provided " "is %s" % n_steps.dtype " n_steps must be an int. dtype provided " "is %s" % n_steps.dtype
) )
...@@ -517,7 +518,7 @@ def scan( ...@@ -517,7 +518,7 @@ def scan(
# If not we need to use copies, that will be replaced at # If not we need to use copies, that will be replaced at
# each frame by the corresponding slice # each frame by the corresponding slice
actual_slice = seq["input"][k - mintap_proxy] actual_slice = seq["input"][k - mintap_proxy]
_seq_val = tensor.as_tensor_variable(seq["input"]) _seq_val = tt.as_tensor_variable(seq["input"])
_seq_val_slice = _seq_val[k - mintap_proxy] _seq_val_slice = _seq_val[k - mintap_proxy]
nw_slice = _seq_val_slice.type() nw_slice = _seq_val_slice.type()
...@@ -579,7 +580,7 @@ def scan( ...@@ -579,7 +580,7 @@ def scan(
if not scan_utils.isNaN_or_Inf_or_None(n_steps): if not scan_utils.isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered # ^ N_steps should also be considered
lengths_vec.append(tensor.as_tensor(n_steps)) lengths_vec.append(tt.as_tensor(n_steps))
if len(lengths_vec) == 0: if len(lengths_vec) == 0:
# ^ No information about the number of steps # ^ No information about the number of steps
...@@ -595,9 +596,9 @@ def scan( ...@@ -595,9 +596,9 @@ def scan(
if scan_utils.isNaN_or_Inf_or_None(n_steps): if scan_utils.isNaN_or_Inf_or_None(n_steps):
actual_n_steps = lengths_vec[0] actual_n_steps = lengths_vec[0]
for contestant in lengths_vec[1:]: for contestant in lengths_vec[1:]:
actual_n_steps = tensor.minimum(actual_n_steps, contestant) actual_n_steps = tt.minimum(actual_n_steps, contestant)
else: else:
actual_n_steps = tensor.as_tensor(n_steps) actual_n_steps = tt.as_tensor(n_steps)
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs] scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
# Conventions : # Conventions :
...@@ -644,10 +645,10 @@ def scan( ...@@ -644,10 +645,10 @@ def scan(
if init_out.get("taps", None) == [-1]: if init_out.get("taps", None) == [-1]:
actual_arg = init_out["initial"] actual_arg = init_out["initial"]
if not isinstance(actual_arg, tensor.Variable): if not isinstance(actual_arg, tt.Variable):
actual_arg = tensor.as_tensor_variable(actual_arg) actual_arg = tt.as_tensor_variable(actual_arg)
arg = safe_new(actual_arg) arg = safe_new(actual_arg)
if isinstance(arg, tensor.Constant): if isinstance(arg, tt.Constant):
# safe new returns a clone of the constants, but that is not # safe new returns a clone of the constants, but that is not
# what we need for initial states # what we need for initial states
arg = arg.type() arg = arg.type()
...@@ -673,7 +674,7 @@ def scan( ...@@ -673,7 +674,7 @@ def scan(
# defined in scan utils # defined in scan utils
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
scan_utils.expand_empty( scan_utils.expand_empty(
tensor.unbroadcast(tensor.shape_padleft(actual_arg), 0), tt.unbroadcast(tt.shape_padleft(actual_arg), 0),
actual_n_steps, actual_n_steps,
) )
) )
...@@ -706,7 +707,7 @@ def scan( ...@@ -706,7 +707,7 @@ def scan(
for k in init_out["taps"]: for k in init_out["taps"]:
# create a new slice # create a new slice
actual_nw_slice = init_out["initial"][k + mintap] actual_nw_slice = init_out["initial"][k + mintap]
_init_out_var = tensor.as_tensor_variable(init_out["initial"]) _init_out_var = tt.as_tensor_variable(init_out["initial"])
_init_out_var_slice = _init_out_var[k + mintap] _init_out_var_slice = _init_out_var[k + mintap]
nw_slice = _init_out_var_slice.type() nw_slice = _init_out_var_slice.type()
...@@ -779,9 +780,7 @@ def scan( ...@@ -779,9 +780,7 @@ def scan(
dummy_args = [ dummy_args = [
arg arg
for arg in args for arg in args
if ( if (not isinstance(arg, SharedVariable) and not isinstance(arg, tt.Constant))
not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant)
)
] ]
# when we apply the lambda expression we get a mixture of update rules # when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated # and outputs that needs to be separated
...@@ -814,10 +813,10 @@ def scan( ...@@ -814,10 +813,10 @@ def scan(
# this will represent only a slice and it will have one # this will represent only a slice and it will have one
# dimension less. # dimension less.
if ( if (
isinstance(inner_out.type, tensor.TensorType) isinstance(inner_out.type, tt.TensorType)
and return_steps.get(pos, 0) != 1 and return_steps.get(pos, 0) != 1
): ):
outputs[pos] = tensor.unbroadcast(tensor.shape_padleft(inner_out), 0) outputs[pos] = tt.unbroadcast(tt.shape_padleft(inner_out), 0)
if return_list is not True and len(outputs) == 1: if return_list is not True and len(outputs) == 1:
outputs = outputs[0] outputs = outputs[0]
...@@ -931,11 +930,11 @@ def scan( ...@@ -931,11 +930,11 @@ def scan(
sit_sot_inner_inputs.append(new_var) sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
scan_utils.expand_empty( scan_utils.expand_empty(
tensor.unbroadcast(tensor.shape_padleft(input.variable), 0), tt.unbroadcast(tt.shape_padleft(input.variable), 0),
actual_n_steps, actual_n_steps,
) )
) )
tensor_update = tensor.as_tensor_variable(input.update) tensor_update = tt.as_tensor_variable(input.update)
sit_sot_inner_outputs.append(tensor_update) sit_sot_inner_outputs.append(tensor_update)
# Note that pos is not a negative index. The sign of pos is used # Note that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the # as a flag to indicate if this output should be part of the
...@@ -975,18 +974,14 @@ def scan( ...@@ -975,18 +974,14 @@ def scan(
other_scan_args += [ other_scan_args += [
arg arg
for arg in non_seqs for arg in non_seqs
if ( if (not isinstance(arg, SharedVariable) and not isinstance(arg, tt.Constant))
not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant)
)
] ]
# Step 5.6 all shared variables with no update rules # Step 5.6 all shared variables with no update rules
other_inner_args += [ other_inner_args += [
safe_new(arg, "_copy") safe_new(arg, "_copy")
for arg in non_seqs for arg in non_seqs
if ( if (not isinstance(arg, SharedVariable) and not isinstance(arg, tt.Constant))
not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant)
)
] ]
givens.update(OrderedDict(zip(other_scan_args, other_inner_args))) givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
...@@ -1063,7 +1058,7 @@ def scan( ...@@ -1063,7 +1058,7 @@ def scan(
for w, w_copy in givens.items(): for w, w_copy in givens.items():
if isinstance(w.type, gpuarray.GpuArrayType) and isinstance( if isinstance(w.type, gpuarray.GpuArrayType) and isinstance(
w_copy.type, tensor.TensorType w_copy.type, tt.TensorType
): ):
for o in inner_outs: for o in inner_outs:
new_givens = traverse(o, w, w_copy, new_givens) new_givens = traverse(o, w, w_copy, new_givens)
...@@ -1121,7 +1116,7 @@ def scan( ...@@ -1121,7 +1116,7 @@ def scan(
scan_inputs = [] scan_inputs = []
for arg in [actual_n_steps] + _scan_inputs: for arg in [actual_n_steps] + _scan_inputs:
try: try:
arg = tensor.as_tensor_variable(arg) arg = tt.as_tensor_variable(arg)
except TypeError: except TypeError:
# This happens for Random States for e.g. but it is a good way # This happens for Random States for e.g. but it is a good way
# to make sure all inputs are tensors. # to make sure all inputs are tensors.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论