提交 219428ba authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Replace theano.tensor alias T with tt in theano.scan_module

上级 c90ef03d
......@@ -46,11 +46,13 @@ import logging
import numpy as np
import theano.tensor as tt
from collections import OrderedDict
from six import integer_types
from theano import compile, gof, tensor, config
from theano import compile, gof, config
from theano.compile import SharedVariable, function, ops
from theano.tensor import opt
from theano.updates import OrderedUpdates
......@@ -59,7 +61,6 @@ from theano.gof.utils import TestValueError
from theano.scan_module import scan_op, scan_utils
from theano.scan_module.scan_utils import safe_new, traverse
# Logging function for sending warning or info
_logger = logging.getLogger("theano.scan_module.scan")
......@@ -142,11 +143,11 @@ def scan(
.. code-block:: python
import theano.tensor as TT
W = TT.matrix()
import theano.tensor as tt
W = tt.matrix()
W_2 = W**2
def f(x):
return TT.dot(x,W_2)
return tt.dot(x,W_2)
The function is expected to return two things. One is a list of
outputs ordered in the same order as ``outputs_info``, with the
......@@ -374,7 +375,7 @@ def scan(
non_seqs = []
for elem in wrap_into_list(non_sequences):
if not isinstance(elem, gof.Variable):
non_seqs.append(tensor.as_tensor_variable(elem))
non_seqs.append(tt.as_tensor_variable(elem))
else:
non_seqs.append(elem)
......@@ -389,11 +390,11 @@ def scan(
else:
try:
n_fixed_steps = opt.get_scalar_constant_value(n_steps)
except tensor.basic.NotScalarConstantError:
except tt.NotScalarConstantError:
n_fixed_steps = None
# Check n_steps is an int
if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in tensor.integer_dtypes:
if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in tt.integer_dtypes:
raise ValueError(
" n_steps must be an int. dtype provided " "is %s" % n_steps.dtype
)
......@@ -517,7 +518,7 @@ def scan(
# If not we need to use copies, that will be replaced at
# each frame by the corresponding slice
actual_slice = seq["input"][k - mintap_proxy]
_seq_val = tensor.as_tensor_variable(seq["input"])
_seq_val = tt.as_tensor_variable(seq["input"])
_seq_val_slice = _seq_val[k - mintap_proxy]
nw_slice = _seq_val_slice.type()
......@@ -579,7 +580,7 @@ def scan(
if not scan_utils.isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered
lengths_vec.append(tensor.as_tensor(n_steps))
lengths_vec.append(tt.as_tensor(n_steps))
if len(lengths_vec) == 0:
# ^ No information about the number of steps
......@@ -595,9 +596,9 @@ def scan(
if scan_utils.isNaN_or_Inf_or_None(n_steps):
actual_n_steps = lengths_vec[0]
for contestant in lengths_vec[1:]:
actual_n_steps = tensor.minimum(actual_n_steps, contestant)
actual_n_steps = tt.minimum(actual_n_steps, contestant)
else:
actual_n_steps = tensor.as_tensor(n_steps)
actual_n_steps = tt.as_tensor(n_steps)
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
# Conventions :
......@@ -644,10 +645,10 @@ def scan(
if init_out.get("taps", None) == [-1]:
actual_arg = init_out["initial"]
if not isinstance(actual_arg, tensor.Variable):
actual_arg = tensor.as_tensor_variable(actual_arg)
if not isinstance(actual_arg, tt.Variable):
actual_arg = tt.as_tensor_variable(actual_arg)
arg = safe_new(actual_arg)
if isinstance(arg, tensor.Constant):
if isinstance(arg, tt.Constant):
# safe new returns a clone of the constants, but that is not
# what we need for initial states
arg = arg.type()
......@@ -673,7 +674,7 @@ def scan(
# defined in scan utils
sit_sot_scan_inputs.append(
scan_utils.expand_empty(
tensor.unbroadcast(tensor.shape_padleft(actual_arg), 0),
tt.unbroadcast(tt.shape_padleft(actual_arg), 0),
actual_n_steps,
)
)
......@@ -706,7 +707,7 @@ def scan(
for k in init_out["taps"]:
# create a new slice
actual_nw_slice = init_out["initial"][k + mintap]
_init_out_var = tensor.as_tensor_variable(init_out["initial"])
_init_out_var = tt.as_tensor_variable(init_out["initial"])
_init_out_var_slice = _init_out_var[k + mintap]
nw_slice = _init_out_var_slice.type()
......@@ -779,9 +780,7 @@ def scan(
dummy_args = [
arg
for arg in args
if (
not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant)
)
if (not isinstance(arg, SharedVariable) and not isinstance(arg, tt.Constant))
]
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
......@@ -814,10 +813,10 @@ def scan(
# this will represent only a slice and it will have one
# dimension less.
if (
isinstance(inner_out.type, tensor.TensorType)
isinstance(inner_out.type, tt.TensorType)
and return_steps.get(pos, 0) != 1
):
outputs[pos] = tensor.unbroadcast(tensor.shape_padleft(inner_out), 0)
outputs[pos] = tt.unbroadcast(tt.shape_padleft(inner_out), 0)
if return_list is not True and len(outputs) == 1:
outputs = outputs[0]
......@@ -931,11 +930,11 @@ def scan(
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
scan_utils.expand_empty(
tensor.unbroadcast(tensor.shape_padleft(input.variable), 0),
tt.unbroadcast(tt.shape_padleft(input.variable), 0),
actual_n_steps,
)
)
tensor_update = tensor.as_tensor_variable(input.update)
tensor_update = tt.as_tensor_variable(input.update)
sit_sot_inner_outputs.append(tensor_update)
# Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the
......@@ -975,18 +974,14 @@ def scan(
other_scan_args += [
arg
for arg in non_seqs
if (
not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant)
)
if (not isinstance(arg, SharedVariable) and not isinstance(arg, tt.Constant))
]
# Step 5.6 all shared variables with no update rules
other_inner_args += [
safe_new(arg, "_copy")
for arg in non_seqs
if (
not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant)
)
if (not isinstance(arg, SharedVariable) and not isinstance(arg, tt.Constant))
]
givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
......@@ -1063,7 +1058,7 @@ def scan(
for w, w_copy in givens.items():
if isinstance(w.type, gpuarray.GpuArrayType) and isinstance(
w_copy.type, tensor.TensorType
w_copy.type, tt.TensorType
):
for o in inner_outs:
new_givens = traverse(o, w, w_copy, new_givens)
......@@ -1121,7 +1116,7 @@ def scan(
scan_inputs = []
for arg in [actual_n_steps] + _scan_inputs:
try:
arg = tensor.as_tensor_variable(arg)
arg = tt.as_tensor_variable(arg)
except TypeError:
# This happens for Random States for e.g. but it is a good way
# to make sure all inputs are tensors.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论