提交 b27500ce authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.utils._multi to apply_across_args

上级 4807ebf9
...@@ -30,7 +30,12 @@ from theano.gof.utils import MetaObject, MethodNotDefined ...@@ -30,7 +30,12 @@ from theano.gof.utils import MetaObject, MethodNotDefined
from theano.gradient import DisconnectedType, grad_undefined from theano.gradient import DisconnectedType, grad_undefined
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.printing import pprint from theano.printing import pprint
from theano.utils import _multi, difference, from_return_values, to_return_values from theano.utils import (
apply_across_args,
difference,
from_return_values,
to_return_values,
)
builtin_bool = bool builtin_bool = bool
...@@ -860,11 +865,11 @@ Scalar.Constant = ScalarConstant ...@@ -860,11 +865,11 @@ Scalar.Constant = ScalarConstant
# Easy constructors # Easy constructors
ints = _multi(int64) ints = apply_across_args(int64)
floats = _multi(float64) floats = apply_across_args(float64)
complexs = _multi(complex128) complexs = apply_across_args(complex128)
complexs64 = _multi(complex64) complexs64 = apply_across_args(complex64)
complexs128 = _multi(complex128) complexs128 = apply_across_args(complex128)
def upcast_out(*types): def upcast_out(*types):
......
...@@ -32,7 +32,7 @@ from theano.tensor.type import TensorType, values_eq_approx_always_true ...@@ -32,7 +32,7 @@ from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst from theano.tensor.type_other import NoneConst
from theano.tensor.utils import _pack from theano.tensor.utils import _pack
from theano.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators from theano.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators
from theano.utils import _multi from theano.utils import apply_across_args
_logger = logging.getLogger("theano.tensor.basic") _logger = logging.getLogger("theano.tensor.basic")
...@@ -713,7 +713,7 @@ def scalar(name=None, dtype=None): ...@@ -713,7 +713,7 @@ def scalar(name=None, dtype=None):
return type(name) return type(name)
scalars, fscalars, dscalars, iscalars, lscalars = _multi( scalars, fscalars, dscalars, iscalars, lscalars = apply_across_args(
scalar, fscalar, dscalar, iscalar, lscalar scalar, fscalar, dscalar, iscalar, lscalar
) )
...@@ -751,7 +751,7 @@ def vector(name=None, dtype=None): ...@@ -751,7 +751,7 @@ def vector(name=None, dtype=None):
return type(name) return type(name)
vectors, fvectors, dvectors, ivectors, lvectors = _multi( vectors, fvectors, dvectors, ivectors, lvectors = apply_across_args(
vector, fvector, dvector, ivector, lvector vector, fvector, dvector, ivector, lvector
) )
...@@ -786,7 +786,7 @@ def matrix(name=None, dtype=None): ...@@ -786,7 +786,7 @@ def matrix(name=None, dtype=None):
return type(name) return type(name)
matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi( matrices, fmatrices, dmatrices, imatrices, lmatrices = apply_across_args(
matrix, fmatrix, dmatrix, imatrix, lmatrix matrix, fmatrix, dmatrix, imatrix, lmatrix
) )
...@@ -821,7 +821,7 @@ def row(name=None, dtype=None): ...@@ -821,7 +821,7 @@ def row(name=None, dtype=None):
return type(name) return type(name)
rows, frows, drows, irows, lrows = _multi(row, frow, drow, irow, lrow) rows, frows, drows, irows, lrows = apply_across_args(row, frow, drow, irow, lrow)
ccol = TensorType("complex64", (False, True)) ccol = TensorType("complex64", (False, True))
zcol = TensorType("complex128", (False, True)) zcol = TensorType("complex128", (False, True))
...@@ -850,7 +850,7 @@ def col(name=None, dtype=None): ...@@ -850,7 +850,7 @@ def col(name=None, dtype=None):
return type(name) return type(name)
cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol) cols, fcols, dcols, icols, lcols = apply_across_args(col, fcol, dcol, icol, lcol)
ctensor3 = TensorType("complex64", ((False,) * 3)) ctensor3 = TensorType("complex64", ((False,) * 3))
ztensor3 = TensorType("complex128", ((False,) * 3)) ztensor3 = TensorType("complex128", ((False,) * 3))
...@@ -879,7 +879,7 @@ def tensor3(name=None, dtype=None): ...@@ -879,7 +879,7 @@ def tensor3(name=None, dtype=None):
return type(name) return type(name)
tensor3s, ftensor3s, dtensor3s, itensor3s, ltensor3s = _multi( tensor3s, ftensor3s, dtensor3s, itensor3s, ltensor3s = apply_across_args(
tensor3, ftensor3, dtensor3, itensor3, ltensor3 tensor3, ftensor3, dtensor3, itensor3, ltensor3
) )
...@@ -910,7 +910,7 @@ def tensor4(name=None, dtype=None): ...@@ -910,7 +910,7 @@ def tensor4(name=None, dtype=None):
return type(name) return type(name)
tensor4s, ftensor4s, dtensor4s, itensor4s, ltensor4s = _multi( tensor4s, ftensor4s, dtensor4s, itensor4s, ltensor4s = apply_across_args(
tensor4, ftensor4, dtensor4, itensor4, ltensor4 tensor4, ftensor4, dtensor4, itensor4, ltensor4
) )
...@@ -941,7 +941,7 @@ def tensor5(name=None, dtype=None): ...@@ -941,7 +941,7 @@ def tensor5(name=None, dtype=None):
return type(name) return type(name)
tensor5s, ftensor5s, dtensor5s, itensor5s, ltensor5s = _multi( tensor5s, ftensor5s, dtensor5s, itensor5s, ltensor5s = apply_across_args(
tensor5, ftensor5, dtensor5, itensor5, ltensor5 tensor5, ftensor5, dtensor5, itensor5, ltensor5
) )
...@@ -972,7 +972,7 @@ def tensor6(name=None, dtype=None): ...@@ -972,7 +972,7 @@ def tensor6(name=None, dtype=None):
return type(name) return type(name)
tensor6s, ftensor6s, dtensor6s, itensor6s, ltensor6s = _multi( tensor6s, ftensor6s, dtensor6s, itensor6s, ltensor6s = apply_across_args(
tensor6, ftensor6, dtensor6, itensor6, ltensor6 tensor6, ftensor6, dtensor6, itensor6, ltensor6
) )
...@@ -1003,7 +1003,7 @@ def tensor7(name=None, dtype=None): ...@@ -1003,7 +1003,7 @@ def tensor7(name=None, dtype=None):
return type(name) return type(name)
tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = _multi( tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args(
tensor7, ftensor7, dtensor7, itensor7, ltensor7 tensor7, ftensor7, dtensor7, itensor7, ltensor7
) )
......
...@@ -393,7 +393,7 @@ class NoDuplicateOptWarningFilter(logging.Filter): ...@@ -393,7 +393,7 @@ class NoDuplicateOptWarningFilter(logging.Filter):
return True return True
def _multi(*fns): def apply_across_args(*fns):
"""Create new functions that distributes the wrapped functions across iterable arguments. """Create new functions that distributes the wrapped functions across iterable arguments.
For example, a function, `fn`, that uses this decorator satisfies For example, a function, `fn`, that uses this decorator satisfies
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论