提交 4807ebf9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move theano.tensor/scalar utility functions to more appropriate modules

上级 deb437d8
...@@ -13,7 +13,6 @@ you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar! ...@@ -13,7 +13,6 @@ you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar!
import math import math
from collections.abc import Callable from collections.abc import Callable
from copy import copy from copy import copy
from functools import partial
from itertools import chain from itertools import chain
from textwrap import dedent from textwrap import dedent
...@@ -31,7 +30,7 @@ from theano.gof.utils import MetaObject, MethodNotDefined ...@@ -31,7 +30,7 @@ from theano.gof.utils import MetaObject, MethodNotDefined
from theano.gradient import DisconnectedType, grad_undefined from theano.gradient import DisconnectedType, grad_undefined
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.printing import pprint from theano.printing import pprint
from theano.utils import difference, from_return_values, to_return_values from theano.utils import _multi, difference, from_return_values, to_return_values
builtin_bool = bool builtin_bool = bool
...@@ -861,20 +860,6 @@ Scalar.Constant = ScalarConstant ...@@ -861,20 +860,6 @@ Scalar.Constant = ScalarConstant
# Easy constructors # Easy constructors
def _multi(*fns):
def f2(f, names):
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns[0])
else:
return [partial(f2, f) for f in fns]
ints = _multi(int64) ints = _multi(int64)
floats = _multi(float64) floats = _multi(float64)
complexs = _multi(complex128) complexs = _multi(complex128)
......
...@@ -4,7 +4,6 @@ import builtins ...@@ -4,7 +4,6 @@ import builtins
import logging import logging
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial
import numpy as np import numpy as np
...@@ -28,10 +27,12 @@ from theano.scalar import int32 ...@@ -28,10 +27,12 @@ from theano.scalar import int32
from theano.tensor import elemwise from theano.tensor import elemwise
# set up the external interface # set up the external interface
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum, _scal_elemwise
from theano.tensor.type import TensorType, values_eq_approx_always_true from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst from theano.tensor.type_other import NoneConst
from theano.tensor.utils import _pack
from theano.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators from theano.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators
from theano.utils import _multi
_logger = logging.getLogger("theano.tensor.basic") _logger = logging.getLogger("theano.tensor.basic")
...@@ -685,27 +686,6 @@ def tensor(*args, **kwargs): ...@@ -685,27 +686,6 @@ def tensor(*args, **kwargs):
return TensorType(*args, **kwargs)(name=name) return TensorType(*args, **kwargs)(name=name)
def _multi(*fns):
def f2(f, *names):
if names and isinstance(names[0], int):
if names == 1:
return f()
else:
return [f() for i in range(names[0])]
if isinstance(names, tuple):
if len(names) == 1:
names = names[0]
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns)
else:
return [partial(f2, f) for f in fns]
cscalar = TensorType("complex64", ()) cscalar = TensorType("complex64", ())
zscalar = TensorType("complex128", ()) zscalar = TensorType("complex128", ())
fscalar = TensorType("float32", ()) fscalar = TensorType("float32", ())
...@@ -1037,129 +1017,6 @@ elemwise.TensorType = TensorType ...@@ -1037,129 +1017,6 @@ elemwise.TensorType = TensorType
elemwise.TensorVariable = TensorVariable elemwise.TensorVariable = TensorVariable
elemwise.TensorConstant = TensorConstant elemwise.TensorConstant = TensorConstant
#########################
# Utilities
#########################
def _scal_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
    """Turn a symbol definition into the elementwise form of a scalar Op.

    The scalar Op is looked up by name on the ``scal`` module.  The name is
    `symbolname` when given, otherwise the ``__name__`` of the decorated
    symbol.  A name ending in ``_inplace`` selects the in-place variant, which
    is built with an inplace pattern of ``{0: 0}``.

    When `nfunc` is not ``None`` it should be a string such that
    ``getattr(numpy, nfunc)`` implements a vectorized version of the
    operation.  `nin` is the number of inputs that function expects and
    `nout` the number of **destination** inputs it takes, i.e. the function
    takes ``nin + nout`` inputs.  ``nout == 0`` means the NumPy function does
    not take an output array argument.

    Used bare (``@_scal_elemwise``) it returns the constructed Elemwise;
    called with keyword arguments only it returns the decorator.
    """
    suffix = "_inplace"
    # Evaluates to `nfunc` itself when `nfunc` is falsy (e.g. ``None``).
    spec = nfunc and (nfunc, nin, nout)

    def construct(symbol):
        nonlocal symbolname
        if not symbolname:
            symbolname = symbol.__name__

        if symbolname.endswith(suffix):
            base_op = getattr(scal, symbolname[: -len(suffix)])
            rval = elemwise.Elemwise(
                base_op.__class__(scal.transfer_type(0)),
                {0: 0},
                name=f"Elemwise{{{symbolname},inplace}}",
                nfunc_spec=spec,
            )
        else:
            rval = elemwise.Elemwise(
                getattr(scal, symbolname),
                name=f"Elemwise{{{symbolname},no_inplace}}",
                nfunc_spec=spec,
            )

        symbol_doc = symbol.__doc__
        if symbol_doc:
            rval.__doc__ = symbol_doc + "\n" + rval.__doc__

        # See the ./epydoc script: this makes epydoc display `rval` as if it
        # were a function rather than an object.
        rval.__epydoc_asRoutine = symbol
        rval.__module__ = symbol.__module__

        printer = printing.FunctionPrinter(symbolname.replace(suffix, "="))
        pprint.assign(rval, printer)

        return rval

    if not symbol:
        return construct
    return construct(symbol[0])
def _pack(x):
"""
Convert x to a list if it is an iterable, otherwise wrap it in a list.
"""
try:
return list(x)
except TypeError:
return [x]
def check_and_normalize_axes(x, axis):
    """Validate `axis` against `x` and return it as a list of integers.

    Negative entries are wrapped by the number of dimensions of `x`,
    duplicates are removed and the result is sorted.  ``None`` (or a
    ``NoneConst`` variable) yields an empty list.

    Parameters
    ----------
    x: Tensor variable
    axis = Integer, tuple or list of integers

    Returns
    -------
    axis: list of integers

    Raises
    ------
    TypeError
        If `axis` is a non-constant variable or of an unsupported type.
    ValueError
        If an entry is out of range for the dimensions of `x`.

    """
    x = as_tensor_variable(x)

    def is_scalar_like(value):
        # Python/NumPy integers and 0-d arrays all denote a single axis.
        return isinstance(value, (int, np.integer)) or (
            isinstance(value, np.ndarray) and value.ndim == 0
        )

    if axis is None:
        axis = []
    elif is_scalar_like(axis):
        axis = [int(axis)]
    elif isinstance(axis, (tuple, list, np.ndarray)):
        axis = [int(i) for i in axis]
    elif not isinstance(axis, Variable):
        raise TypeError(
            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
        )
    elif NoneConst.equals(axis):
        axis = []
    elif not isinstance(axis, TensorConstant):
        raise TypeError(f"Computation needs a constant axis. Got {axis}")
    else:
        assert axis.dtype in integer_dtypes
        data = axis.data
        if is_scalar_like(data):
            axis = [int(data)]
        elif isinstance(data, (list, np.ndarray)):
            axis = [int(i) for i in data]

    if len(axis) > 0:
        ndim = x.type.ndim
        for pos, value in enumerate(axis):
            if value < 0:
                value += ndim
                axis[pos] = value
            if not 0 <= value < ndim:
                raise ValueError(
                    f"Computation needs a valid axis number for {int(ndim)}-D tensor. Got {int(value)}"
                )
        axis = sorted(set(axis))

    return axis
######################### #########################
# Casting Operations # Casting Operations
######################### #########################
...@@ -1736,6 +1593,59 @@ def makeKeepDims(x, y, axis): ...@@ -1736,6 +1593,59 @@ def makeKeepDims(x, y, axis):
return DimShuffle(y.type.broadcastable, new_dims)(y) return DimShuffle(y.type.broadcastable, new_dims)(y)
def check_and_normalize_axes(x, axis):
    """Check axes, normalize them and convert them to a list of Python ints.

    An empty list is returned when `axis` is ``None``.  Otherwise negative
    axes are shifted by the number of dimensions of `x`, each axis is
    range-checked, and the unique axes are returned in ascending order.

    Parameters
    ----------
    x: Tensor variable
    axis = Integer, tuple or list of integers

    Returns
    -------
    axis: list of integers

    """
    x = as_tensor_variable(x)
    int_types = (int, np.integer)

    if axis is None:
        axis = []
    elif isinstance(axis, int_types) or (
        isinstance(axis, np.ndarray) and axis.ndim == 0
    ):
        axis = [int(axis)]
    elif isinstance(axis, (tuple, list, np.ndarray)):
        axis = list(map(int, axis))
    elif isinstance(axis, Variable):
        if NoneConst.equals(axis):
            axis = []
        elif isinstance(axis, TensorConstant):
            assert axis.dtype in integer_dtypes
            # Only constants can be turned into concrete axis numbers.
            if isinstance(axis.data, int_types) or (
                isinstance(axis.data, np.ndarray) and axis.data.ndim == 0
            ):
                axis = [int(axis.data)]
            elif isinstance(axis.data, (list, np.ndarray)):
                axis = list(map(int, axis.data))
        else:
            raise TypeError(f"Computation needs a constant axis. Got {axis}")
    else:
        raise TypeError(
            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
        )

    if len(axis) > 0:
        n_dims = x.type.ndim
        wrapped = []
        for entry in axis:
            # Negative axes count from the end.
            shifted = entry + n_dims if entry < 0 else entry
            if shifted < 0 or shifted >= n_dims:
                raise ValueError(
                    f"Computation needs a valid axis number for {int(n_dims)}-D tensor. Got {int(shifted)}"
                )
            wrapped.append(shifted)
        axis = sorted(set(wrapped))

    return axis
@constructor @constructor
def max_and_argmax(a, axis=None, keepdims=False): def max_and_argmax(a, axis=None, keepdims=False):
""" """
......
...@@ -12,7 +12,7 @@ from theano.gof.op import COp, ExternalCOp, OpenMPOp ...@@ -12,7 +12,7 @@ from theano.gof.op import COp, ExternalCOp, OpenMPOp
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.misc.frozendict import frozendict from theano.misc.frozendict import frozendict
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.printing import pprint from theano.printing import FunctionPrinter, pprint
from theano.scalar import get_scalar_type from theano.scalar import get_scalar_type
from theano.tensor import elemwise_cgen as cgen from theano.tensor import elemwise_cgen as cgen
from theano.utils import uniq from theano.utils import uniq
...@@ -2240,3 +2240,55 @@ class ProdWithoutZeros(CAReduceDtype): ...@@ -2240,3 +2240,55 @@ class ProdWithoutZeros(CAReduceDtype):
"`product(a, no_zeros_in_input=True)`.", "`product(a, no_zeros_in_input=True)`.",
) )
return [a_grad] return [a_grad]
def _scal_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
    """Replace a symbol definition with an `Elemwise`-wrapped version of the corresponding scalar `Op`.

    The scalar `Op` is looked up by name on the ``scalar`` module.  The name
    is `symbolname` when given, otherwise the ``__name__`` of the decorated
    symbol.  A name ending in ``_inplace`` selects the in-place variant, built
    with an inplace pattern of ``{0: 0}``.

    If it is not ``None``, the `nfunc` argument should be a string such that
    ``getattr(numpy, nfunc)`` implements a vectorized version of the `Elemwise`
    operation.  `nin` is the number of inputs expected by that function, and
    `nout` is the number of **destination** inputs it takes.  That is, the
    function should take ``nin + nout`` inputs.  ``nout == 0`` means that the
    NumPy function does not take a NumPy array argument to put its result in.

    Used bare (``@_scal_elemwise``) this returns the constructed `Elemwise`;
    called with keyword arguments only it returns the decorator.
    """
    inplace_suffix = "_inplace"
    # Evaluates to `nfunc` itself when `nfunc` is falsy (e.g. ``None``).
    nfunc_spec = nfunc and (nfunc, nin, nout)

    def construct(symbol):
        # Resolve the name into a local rather than writing it back through
        # ``nonlocal``: otherwise a reused decorator would keep the first
        # symbol's name for every later symbol.
        name = symbolname or symbol.__name__

        if name.endswith(inplace_suffix):
            scalar_op = getattr(scalar, name[: -len(inplace_suffix)])
            inplace_scalar_op = scalar_op.__class__(scalar.transfer_type(0))
            rval = Elemwise(
                inplace_scalar_op,
                {0: 0},
                name=f"Elemwise{{{name},inplace}}",
                nfunc_spec=nfunc_spec,
            )
        else:
            scalar_op = getattr(scalar, name)
            rval = Elemwise(
                scalar_op,
                name=f"Elemwise{{{name},no_inplace}}",
                nfunc_spec=nfunc_spec,
            )

        if symbol.__doc__:
            rval.__doc__ = symbol.__doc__ + "\n" + rval.__doc__

        # for the meaning of this see the ./epydoc script
        # it makes epydoc display rval as if it were a function, not an object
        rval.__epydoc_asRoutine = symbol
        rval.__module__ = symbol.__module__

        pprint.assign(rval, FunctionPrinter(name.replace(inplace_suffix, "=")))

        return rval

    if symbol:
        return construct(symbol[0])
    return construct
from theano import printing from theano import printing
from theano.printing import pprint from theano.printing import pprint
from theano.tensor import elemwise from theano.tensor.elemwise import DimShuffle, _scal_elemwise
from theano.tensor.basic import _scal_elemwise
@_scal_elemwise @_scal_elemwise
...@@ -360,4 +359,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right")) ...@@ -360,4 +359,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def transpose_inplace(x, **kwargs): def transpose_inplace(x, **kwargs):
"Perform a transpose on a tensor without copying the underlying storage" "Perform a transpose on a tensor without copying the underlying storage"
dims = list(range(x.ndim - 1, -1, -1)) dims = list(range(x.ndim - 1, -1, -1))
return elemwise.DimShuffle(x.broadcastable, dims, inplace=True)(x) return DimShuffle(x.broadcastable, dims, inplace=True)(x)
...@@ -97,3 +97,13 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -97,3 +97,13 @@ def shape_of_variables(fgraph, input_shapes):
sym_to_num_dict[sym] for sym in fgraph.shape_feature.shape_of[var] sym_to_num_dict[sym] for sym in fgraph.shape_feature.shape_of[var]
) )
return l return l
def _pack(x):
"""
Convert x to a list if it is an iterable, otherwise wrap it in a list.
"""
try:
return list(x)
except TypeError:
return [x]
"""Utility functions that only depend on the standard library.""" """Utility functions that only depend on the standard library."""
import hashlib import hashlib
import inspect import inspect
import logging import logging
...@@ -12,7 +11,7 @@ import traceback ...@@ -12,7 +11,7 @@ import traceback
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import partial, wraps
__all__ = [ __all__ = [
...@@ -392,3 +391,30 @@ class NoDuplicateOptWarningFilter(logging.Filter): ...@@ -392,3 +391,30 @@ class NoDuplicateOptWarningFilter(logging.Filter):
self.prev_msgs.add(msg) self.prev_msgs.add(msg)
return True return True
return True return True
def _multi(*fns):
"""Create new functions that distributes the wrapped functions across iterable arguments.
For example, a function, `fn`, that uses this decorator satisfies
`fn("hi") == [fn("h"), fn("i")]`.
"""
def f2(f, *names):
if names and isinstance(names[0], int):
if names == 1:
return f()
else:
return [f() for i in range(names[0])]
if isinstance(names, tuple):
if len(names) == 1:
names = names[0]
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns[0])
else:
return [partial(f2, f) for f in fns]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论