提交 b27500ce authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.utils._multi to apply_across_args

上级 4807ebf9
......@@ -30,7 +30,12 @@ from theano.gof.utils import MetaObject, MethodNotDefined
from theano.gradient import DisconnectedType, grad_undefined
from theano.misc.safe_asarray import _asarray
from theano.printing import pprint
from theano.utils import _multi, difference, from_return_values, to_return_values
from theano.utils import (
apply_across_args,
difference,
from_return_values,
to_return_values,
)
builtin_bool = bool
......@@ -860,11 +865,11 @@ Scalar.Constant = ScalarConstant
# Easy constructors
ints = _multi(int64)
floats = _multi(float64)
complexs = _multi(complex128)
complexs64 = _multi(complex64)
complexs128 = _multi(complex128)
ints = apply_across_args(int64)
floats = apply_across_args(float64)
complexs = apply_across_args(complex128)
complexs64 = apply_across_args(complex64)
complexs128 = apply_across_args(complex128)
def upcast_out(*types):
......
......@@ -32,7 +32,7 @@ from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst
from theano.tensor.utils import _pack
from theano.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators
from theano.utils import _multi
from theano.utils import apply_across_args
_logger = logging.getLogger("theano.tensor.basic")
......@@ -713,7 +713,7 @@ def scalar(name=None, dtype=None):
return type(name)
scalars, fscalars, dscalars, iscalars, lscalars = _multi(
scalars, fscalars, dscalars, iscalars, lscalars = apply_across_args(
scalar, fscalar, dscalar, iscalar, lscalar
)
......@@ -751,7 +751,7 @@ def vector(name=None, dtype=None):
return type(name)
vectors, fvectors, dvectors, ivectors, lvectors = _multi(
vectors, fvectors, dvectors, ivectors, lvectors = apply_across_args(
vector, fvector, dvector, ivector, lvector
)
......@@ -786,7 +786,7 @@ def matrix(name=None, dtype=None):
return type(name)
matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(
matrices, fmatrices, dmatrices, imatrices, lmatrices = apply_across_args(
matrix, fmatrix, dmatrix, imatrix, lmatrix
)
......@@ -821,7 +821,7 @@ def row(name=None, dtype=None):
return type(name)
rows, frows, drows, irows, lrows = _multi(row, frow, drow, irow, lrow)
rows, frows, drows, irows, lrows = apply_across_args(row, frow, drow, irow, lrow)
ccol = TensorType("complex64", (False, True))
zcol = TensorType("complex128", (False, True))
......@@ -850,7 +850,7 @@ def col(name=None, dtype=None):
return type(name)
cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol)
cols, fcols, dcols, icols, lcols = apply_across_args(col, fcol, dcol, icol, lcol)
ctensor3 = TensorType("complex64", ((False,) * 3))
ztensor3 = TensorType("complex128", ((False,) * 3))
......@@ -879,7 +879,7 @@ def tensor3(name=None, dtype=None):
return type(name)
tensor3s, ftensor3s, dtensor3s, itensor3s, ltensor3s = _multi(
tensor3s, ftensor3s, dtensor3s, itensor3s, ltensor3s = apply_across_args(
tensor3, ftensor3, dtensor3, itensor3, ltensor3
)
......@@ -910,7 +910,7 @@ def tensor4(name=None, dtype=None):
return type(name)
tensor4s, ftensor4s, dtensor4s, itensor4s, ltensor4s = _multi(
tensor4s, ftensor4s, dtensor4s, itensor4s, ltensor4s = apply_across_args(
tensor4, ftensor4, dtensor4, itensor4, ltensor4
)
......@@ -941,7 +941,7 @@ def tensor5(name=None, dtype=None):
return type(name)
tensor5s, ftensor5s, dtensor5s, itensor5s, ltensor5s = _multi(
tensor5s, ftensor5s, dtensor5s, itensor5s, ltensor5s = apply_across_args(
tensor5, ftensor5, dtensor5, itensor5, ltensor5
)
......@@ -972,7 +972,7 @@ def tensor6(name=None, dtype=None):
return type(name)
tensor6s, ftensor6s, dtensor6s, itensor6s, ltensor6s = _multi(
tensor6s, ftensor6s, dtensor6s, itensor6s, ltensor6s = apply_across_args(
tensor6, ftensor6, dtensor6, itensor6, ltensor6
)
......@@ -1003,7 +1003,7 @@ def tensor7(name=None, dtype=None):
return type(name)
tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = _multi(
tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args(
tensor7, ftensor7, dtensor7, itensor7, ltensor7
)
......
......@@ -393,7 +393,7 @@ class NoDuplicateOptWarningFilter(logging.Filter):
return True
def _multi(*fns):
def apply_across_args(*fns):
"""Create new functions that distributes the wrapped functions across iterable arguments.
For example, a function, `fn`, that uses this decorator satisfies
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论