提交 e701d8f8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.tensor.utils._pack to as_list

上级 b27500ce
...@@ -30,7 +30,7 @@ from theano.tensor import elemwise ...@@ -30,7 +30,7 @@ from theano.tensor import elemwise
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum, _scal_elemwise from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum, _scal_elemwise
from theano.tensor.type import TensorType, values_eq_approx_always_true from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst from theano.tensor.type_other import NoneConst
from theano.tensor.utils import _pack from theano.tensor.utils import as_list
from theano.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators from theano.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators
from theano.utils import apply_across_args from theano.utils import apply_across_args
...@@ -6317,7 +6317,7 @@ def _tensordot_as_dot(a, b, axes, dot, batched): ...@@ -6317,7 +6317,7 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
# if 'axes' is a list, transpose a and b such that the summed axes of a # if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first. # are last and the summed axes of b are first.
else: else:
axes = [_pack(axes_) for axes_ in axes] axes = [as_list(axes_) for axes_ in axes]
if len(axes[0]) != len(axes[1]): if len(axes[0]) != len(axes[1]):
raise ValueError("Axes elements must have the same length.") raise ValueError("Axes elements must have the same length.")
......
...@@ -99,10 +99,8 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -99,10 +99,8 @@ def shape_of_variables(fgraph, input_shapes):
return l return l
def _pack(x): def as_list(x):
""" """Convert x to a list if it is an iterable; otherwise, wrap it in a list."""
Convert x to a list if it is an iterable, otherwise wrap it in a list.
"""
try: try:
return list(x) return list(x)
except TypeError: except TypeError:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论