提交 4807ebf9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move theano.tensor/scalar utility functions to more appropriate modules

上级 deb437d8
......@@ -13,7 +13,6 @@ you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar!
import math
from collections.abc import Callable
from copy import copy
from functools import partial
from itertools import chain
from textwrap import dedent
......@@ -31,7 +30,7 @@ from theano.gof.utils import MetaObject, MethodNotDefined
from theano.gradient import DisconnectedType, grad_undefined
from theano.misc.safe_asarray import _asarray
from theano.printing import pprint
from theano.utils import difference, from_return_values, to_return_values
from theano.utils import _multi, difference, from_return_values, to_return_values
builtin_bool = bool
......@@ -861,20 +860,6 @@ Scalar.Constant = ScalarConstant
# Easy constructors
def _multi(*fns):
def f2(f, names):
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns[0])
else:
return [partial(f2, f) for f in fns]
# Convenience multi-constructors built from ``_multi`` above: e.g.
# ``ints("xy")`` returns ``[int64("x"), int64("y")]`` while ``ints("x")``
# returns a single ``int64("x")``.
ints = _multi(int64)
floats = _multi(float64)
complexs = _multi(complex128)
......
......@@ -4,7 +4,6 @@ import builtins
import logging
import warnings
from collections.abc import Sequence
from functools import partial
import numpy as np
......@@ -28,10 +27,12 @@ from theano.scalar import int32
from theano.tensor import elemwise
# set up the external interface
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum, _scal_elemwise
from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst
from theano.tensor.utils import _pack
from theano.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators
from theano.utils import _multi
_logger = logging.getLogger("theano.tensor.basic")
......@@ -685,27 +686,6 @@ def tensor(*args, **kwargs):
return TensorType(*args, **kwargs)(name=name)
def _multi(*fns):
def f2(f, *names):
if names and isinstance(names[0], int):
if names == 1:
return f()
else:
return [f() for i in range(names[0])]
if isinstance(names, tuple):
if len(names) == 1:
names = names[0]
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns)
else:
return [partial(f2, f) for f in fns]
# Predefined 0-d ("scalar") tensor types, one per dtype; the single-letter
# prefix encodes the dtype: c = complex64, z = complex128, f = float32.
cscalar = TensorType("complex64", ())
zscalar = TensorType("complex128", ())
fscalar = TensorType("float32", ())
......@@ -1037,129 +1017,6 @@ elemwise.TensorType = TensorType
elemwise.TensorVariable = TensorVariable
elemwise.TensorConstant = TensorConstant
#########################
# Utilities
#########################
def _scal_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
    """
    Replace a symbol definition with an elementwise version of the
    corresponding scalar Op. If it is not None, the nfunc argument
    should be a string such that getattr(numpy, nfunc) implements
    a vectorized version of the elemwise operation. nin is the number
    of inputs expected by that function, and nout is the number of
    **destination** inputs it takes. That is, the function should
    take nin+nout inputs. nout == 0 means that the numpy function
    does not take a numpy array argument to put its result in.
    """

    def construct(symbol):
        name = symbolname or symbol.__name__
        spec = nfunc and (nfunc, nin, nout)
        if name.endswith("_inplace"):
            # In-place variant: wrap the scalar op so that output 0 reuses
            # the storage of input 0 (transfer_type(0)).
            base_op = getattr(scal, name[: -len("_inplace")])
            rval = elemwise.Elemwise(
                base_op.__class__(scal.transfer_type(0)),
                {0: 0},
                name=f"Elemwise{{{name},inplace}}",
                nfunc_spec=spec,
            )
        else:
            rval = elemwise.Elemwise(
                getattr(scal, name),
                name=f"Elemwise{{{name},no_inplace}}",
                nfunc_spec=spec,
            )
        if getattr(symbol, "__doc__"):
            rval.__doc__ = symbol.__doc__ + "\n" + rval.__doc__
        # for the meaning of this see the ./epydoc script
        # it makes epydoc display rval as if it were a function, not an object
        rval.__epydoc_asRoutine = symbol
        rval.__module__ = symbol.__module__
        pprint.assign(
            rval, printing.FunctionPrinter(name.replace("_inplace", "="))
        )
        return rval

    # Used bare (@_scal_elemwise) or with keyword arguments
    # (@_scal_elemwise_with(...) style): return the result directly in the
    # first case, the decorator itself in the second.
    if symbol:
        return construct(symbol[0])
    else:
        return construct
def _pack(x):
"""
Convert x to a list if it is an iterable, otherwise wrap it in a list.
"""
try:
return list(x)
except TypeError:
return [x]
def check_and_normalize_axes(x, axis):
    """Validate ``axis`` against ``x`` and return it as a sorted list of ints.

    Parameters
    ----------
    x: Tensor variable
        The tensor the axes refer to.
    axis = Integer, tuple or list of integers
        May also be ``None`` (meaning "no axes") or a constant
        ``TensorVariable`` holding integers.

    Returns
    -------
    axis: list of integers
        Unique, non-negative axis indices in ascending order; empty when
        ``axis`` is None (or the ``NoneConst`` variable).
    """
    x = as_tensor_variable(x)
    if axis is None:
        axes = []
    elif isinstance(axis, (int, np.integer)) or (
        isinstance(axis, np.ndarray) and axis.ndim == 0
    ):
        axes = [int(axis)]
    elif isinstance(axis, (tuple, list, np.ndarray)):
        axes = [int(i) for i in axis]
    elif isinstance(axis, Variable):
        if NoneConst.equals(axis):
            axes = []
        elif not isinstance(axis, TensorConstant):
            raise TypeError(f"Computation needs a constant axis. Got {axis}")
        else:
            assert axis.dtype in integer_dtypes
            data = axis.data
            if isinstance(data, (int, np.integer)) or (
                isinstance(data, np.ndarray) and data.ndim == 0
            ):
                axes = [int(data)]
            elif isinstance(data, (list, np.ndarray)):
                axes = [int(i) for i in data]
    else:
        raise TypeError(
            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
        )
    # Convert negative indices to their non-negative equivalents and reject
    # anything outside the tensor's rank.
    ndim = x.type.ndim
    checked = []
    for a in axes:
        if a < 0:
            a += ndim
        if a < 0 or a >= ndim:
            raise ValueError(
                f"Computation needs a valid axis number for {int(ndim)}-D tensor. Got {int(a)}"
            )
        checked.append(a)
    return sorted(set(checked))
#########################
# Casting Operations
#########################
......@@ -1736,6 +1593,59 @@ def makeKeepDims(x, y, axis):
return DimShuffle(y.type.broadcastable, new_dims)(y)
def check_and_normalize_axes(x, axis):
    """Validate ``axis`` against ``x`` and return it as a sorted list of ints.

    Parameters
    ----------
    x: Tensor variable
        The tensor the axes refer to.
    axis = Integer, tuple or list of integers
        May also be ``None`` (meaning "no axes") or a constant
        ``TensorVariable`` holding integers.

    Returns
    -------
    axis: list of integers
        Unique, non-negative axis indices in ascending order; empty when
        ``axis`` is None (or the ``NoneConst`` variable).
    """
    x = as_tensor_variable(x)
    if axis is None:
        axes = []
    elif isinstance(axis, (int, np.integer)) or (
        isinstance(axis, np.ndarray) and axis.ndim == 0
    ):
        axes = [int(axis)]
    elif isinstance(axis, (tuple, list, np.ndarray)):
        axes = [int(i) for i in axis]
    elif isinstance(axis, Variable):
        if NoneConst.equals(axis):
            axes = []
        elif not isinstance(axis, TensorConstant):
            raise TypeError(f"Computation needs a constant axis. Got {axis}")
        else:
            assert axis.dtype in integer_dtypes
            data = axis.data
            if isinstance(data, (int, np.integer)) or (
                isinstance(data, np.ndarray) and data.ndim == 0
            ):
                axes = [int(data)]
            elif isinstance(data, (list, np.ndarray)):
                axes = [int(i) for i in data]
    else:
        raise TypeError(
            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
        )
    # Convert negative indices to their non-negative equivalents and reject
    # anything outside the tensor's rank.
    ndim = x.type.ndim
    checked = []
    for a in axes:
        if a < 0:
            a += ndim
        if a < 0 or a >= ndim:
            raise ValueError(
                f"Computation needs a valid axis number for {int(ndim)}-D tensor. Got {int(a)}"
            )
        checked.append(a)
    return sorted(set(checked))
@constructor
def max_and_argmax(a, axis=None, keepdims=False):
"""
......
......@@ -12,7 +12,7 @@ from theano.gof.op import COp, ExternalCOp, OpenMPOp
from theano.gradient import DisconnectedType
from theano.misc.frozendict import frozendict
from theano.misc.safe_asarray import _asarray
from theano.printing import pprint
from theano.printing import FunctionPrinter, pprint
from theano.scalar import get_scalar_type
from theano.tensor import elemwise_cgen as cgen
from theano.utils import uniq
......@@ -2240,3 +2240,55 @@ class ProdWithoutZeros(CAReduceDtype):
"`product(a, no_zeros_in_input=True)`.",
)
return [a_grad]
def _scal_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
    """Replace a symbol definition with an `Elemwise`-wrapped version of the corresponding scalar `Op`.

    If it is not ``None``, the `nfunc` argument should be a string such that
    ``getattr(numpy, nfunc)`` implements a vectorized version of the `Elemwise`
    operation. `nin` is the number of inputs expected by that function, and nout
    is the number of **destination** inputs it takes. That is, the function
    should take nin + nout inputs. `nout == 0` means that the numpy function
    does not take a NumPy array argument to put its result in.
    """

    def construct(symbol):
        name = symbolname or symbol.__name__
        spec = nfunc and (nfunc, nin, nout)
        if name.endswith("_inplace"):
            # In-place variant: wrap the scalar op so that output 0 reuses
            # the storage of input 0 (transfer_type(0)).
            base_op = getattr(scalar, name[: -len("_inplace")])
            rval = Elemwise(
                base_op.__class__(scalar.transfer_type(0)),
                {0: 0},
                name=f"Elemwise{{{name},inplace}}",
                nfunc_spec=spec,
            )
        else:
            rval = Elemwise(
                getattr(scalar, name),
                name=f"Elemwise{{{name},no_inplace}}",
                nfunc_spec=spec,
            )
        if getattr(symbol, "__doc__"):
            rval.__doc__ = symbol.__doc__ + "\n" + rval.__doc__
        # for the meaning of this see the ./epydoc script
        # it makes epydoc display rval as if it were a function, not an object
        rval.__epydoc_asRoutine = symbol
        rval.__module__ = symbol.__module__
        pprint.assign(rval, FunctionPrinter(name.replace("_inplace", "=")))
        return rval

    # Used bare (@_scal_elemwise) or as a decorator factory: return the
    # wrapped Op directly in the first case, the decorator in the second.
    if symbol:
        return construct(symbol[0])
    else:
        return construct
from theano import printing
from theano.printing import pprint
from theano.tensor import elemwise
from theano.tensor.basic import _scal_elemwise
from theano.tensor.elemwise import DimShuffle, _scal_elemwise
@_scal_elemwise
......@@ -360,4 +359,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def transpose_inplace(x, **kwargs):
    """Transpose a tensor by reversing its axes, reusing its underlying storage."""
    reversed_axes = list(range(x.ndim - 1, -1, -1))
    return DimShuffle(x.broadcastable, reversed_axes, inplace=True)(x)
......@@ -97,3 +97,13 @@ def shape_of_variables(fgraph, input_shapes):
sym_to_num_dict[sym] for sym in fgraph.shape_feature.shape_of[var]
)
return l
def _pack(x):
"""
Convert x to a list if it is an iterable, otherwise wrap it in a list.
"""
try:
return list(x)
except TypeError:
return [x]
"""Utility functions that only depend on the standard library."""
import hashlib
import inspect
import logging
......@@ -12,7 +11,7 @@ import traceback
import warnings
from collections import OrderedDict
from collections.abc import Callable
from functools import wraps
from functools import partial, wraps
__all__ = [
......@@ -392,3 +391,30 @@ class NoDuplicateOptWarningFilter(logging.Filter):
self.prev_msgs.add(msg)
return True
return True
def _multi(*fns):
"""Create new functions that distributes the wrapped functions across iterable arguments.
For example, a function, `fn`, that uses this decorator satisfies
`fn("hi") == [fn("h"), fn("i")]`.
"""
def f2(f, *names):
if names and isinstance(names[0], int):
if names == 1:
return f()
else:
return [f() for i in range(names[0])]
if isinstance(names, tuple):
if len(names) == 1:
names = names[0]
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns[0])
else:
return [partial(f2, f) for f in fns]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论