Unverified 提交 430d068d authored 作者: Rémi Louf's avatar Rémi Louf 提交者: GitHub

Rename `abs_` to `abs` (#483)

Rename aesara.tensor.math.abs_ to abs
上级 c60579c8
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import aesara import aesara
from aesara.compile.mode import Mode, get_mode from aesara.compile.mode import Mode, get_mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.tensor.math import abs_ from aesara.tensor.math import abs as aet_abs
from aesara.tensor.math import max as aet_max from aesara.tensor.math import max as aet_max
from aesara.tensor.math import min as aet_min from aesara.tensor.math import min as aet_min
from aesara.tensor.type import discrete_dtypes from aesara.tensor.type import discrete_dtypes
...@@ -174,7 +174,7 @@ def f_compute(op): ...@@ -174,7 +174,7 @@ def f_compute(op):
f_gpua_min = f_compute(aet_min) f_gpua_min = f_compute(aet_min)
f_gpua_max = f_compute(aet_max) f_gpua_max = f_compute(aet_max)
f_gpua_absmax = f_compute(lambda x: aet_max(abs_(x))) f_gpua_absmax = f_compute(lambda x: aet_max(aet_abs(x)))
class NanGuardMode(Mode): class NanGuardMode(Mode):
......
...@@ -10,6 +10,7 @@ If you want to use a scalar variable in an Aesara graph, ...@@ -10,6 +10,7 @@ If you want to use a scalar variable in an Aesara graph,
you probably want to use aesara.tensor.[c,z,f,d,b,w,i,l,]scalar! you probably want to use aesara.tensor.[c,z,f,d,b,w,i,l,]scalar!
""" """
import builtins
import math import math
from collections.abc import Callable from collections.abc import Callable
from copy import copy from copy import copy
...@@ -45,6 +46,10 @@ builtin_int = int ...@@ -45,6 +46,10 @@ builtin_int = int
builtin_float = float builtin_float = float
# We capture the builtins that we are going to replace to follow the numpy API
_abs = builtins.abs
class ComplexError(NotImplementedError): class ComplexError(NotImplementedError):
""" """
Raised if complex numbers are used in an unsupported operation. Raised if complex numbers are used in an unsupported operation.
...@@ -383,7 +388,7 @@ class Scalar(CType): ...@@ -383,7 +388,7 @@ class Scalar(CType):
diff = a - b diff = a - b
if diff == 0: if diff == 0:
return True return True
return abs(diff) <= (abs(a) * tolerance) + (abs(b) * tolerance) return _abs(diff) <= (_abs(a) * tolerance) + (_abs(b) * tolerance)
def c_element_type(self): def c_element_type(self):
return self.dtype_specs()[1] return self.dtype_specs()[1]
...@@ -757,7 +762,7 @@ class _scalar_py_operators: ...@@ -757,7 +762,7 @@ class _scalar_py_operators:
# UNARY # UNARY
def __abs__(self): def __abs__(self):
return abs_(self) return abs(self)
def __neg__(self): def __neg__(self):
return neg(self) return neg(self)
...@@ -2543,7 +2548,7 @@ class Abs(UnaryScalarOp): ...@@ -2543,7 +2548,7 @@ class Abs(UnaryScalarOp):
if x.type in float_types: if x.type in float_types:
return (gz * sgn(x),) return (gz * sgn(x),)
return (gz * x / abs(x),) # formula works for complex and real return (gz * x / _abs(x),) # formula works for complex and real
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -2563,7 +2568,7 @@ class Abs(UnaryScalarOp): ...@@ -2563,7 +2568,7 @@ class Abs(UnaryScalarOp):
raise NotImplementedError("type not supported", type) raise NotImplementedError("type not supported", type)
abs_ = Abs(same_out) abs = Abs(same_out)
class Sgn(UnaryScalarOp): class Sgn(UnaryScalarOp):
...@@ -3862,7 +3867,7 @@ class Angle(UnaryScalarOp): ...@@ -3862,7 +3867,7 @@ class Angle(UnaryScalarOp):
(gtheta,) = gout (gtheta,) = gout
x = real(c) x = real(c)
y = imag(c) y = imag(c)
r = abs(c) r = _abs(c)
gr = -gtheta * y / (r ** 2 * sqrt(1 - (y / r) ** 2)) gr = -gtheta * y / (r ** 2 * sqrt(1 - (y / r) ** 2))
gx = gr * x / r gx = gr * x / r
......
...@@ -20,7 +20,7 @@ from aesara.scalar import int32 as int_t ...@@ -20,7 +20,7 @@ from aesara.scalar import int32 as int_t
from aesara.scalar import upcast from aesara.scalar import upcast
from aesara.tensor import basic as aet from aesara.tensor import basic as aet
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs_ from aesara.tensor.math import abs as aet_abs
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
from aesara.tensor.math import eq, ge, lt, maximum, minimum, or_, prod from aesara.tensor.math import eq, ge, lt, maximum, minimum, or_, prod
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
...@@ -1084,7 +1084,7 @@ class FillDiagonalOffset(Op): ...@@ -1084,7 +1084,7 @@ class FillDiagonalOffset(Op):
# only valid for matrices # only valid for matrices
wr_a = fill_diagonal_offset(grad, 0, offset) wr_a = fill_diagonal_offset(grad, 0, offset)
offset_abs = abs_(offset) offset_abs = aet_abs(offset)
pos_offset_flag = ge(offset, 0) pos_offset_flag = ge(offset, 0)
neg_offset_flag = lt(offset, 0) neg_offset_flag = lt(offset, 0)
min_wh = minimum(width, height) min_wh = minimum(width, height)
......
...@@ -54,7 +54,7 @@ def invert_inplace(a): ...@@ -54,7 +54,7 @@ def invert_inplace(a):
@scalar_elemwise @scalar_elemwise
def abs__inplace(a): def abs_inplace(a):
"""|`a`| (inplace on `a`)""" """|`a`| (inplace on `a`)"""
......
import builtins
import warnings import warnings
import numpy as np import numpy as np
...@@ -45,6 +46,10 @@ from aesara.tensor.utils import as_list ...@@ -45,6 +46,10 @@ from aesara.tensor.utils import as_list
from aesara.tensor.var import TensorConstant, _tensor_py_operators from aesara.tensor.var import TensorConstant, _tensor_py_operators
# We capture the builtins that we are going to replace to follow the numpy API
_abs = builtins.abs
if int(config.tensor__cmp_sloppy) > 1: if int(config.tensor__cmp_sloppy) > 1:
# This config variable is a quick-and-dirty way to get low-precision # This config variable is a quick-and-dirty way to get low-precision
# comparisons. For a more precise setting of these tolerances set # comparisons. For a more precise setting of these tolerances set
...@@ -955,8 +960,8 @@ def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): ...@@ -955,8 +960,8 @@ def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
""" """
# close will be an int8 array of 1 where within tolerance # close will be an int8 array of 1 where within tolerance
# and 0 where not within tolerance or there was a nan or inf value. # and 0 where not within tolerance or there was a nan or inf value.
diff = abs(a - b) diff = _abs(a - b)
tolerance = atol + rtol * abs(b) tolerance = atol + rtol * _abs(b)
close_prelim = le(diff, tolerance) close_prelim = le(diff, tolerance)
a_nan = isnan(a) a_nan = isnan(a)
...@@ -1033,16 +1038,15 @@ bitwise_not = invert # numpy alias for it ...@@ -1033,16 +1038,15 @@ bitwise_not = invert # numpy alias for it
@scalar_elemwise @scalar_elemwise
def abs_(a): def abs(a):
"""|`a`| """|`a`|"""
TensorVariable overloads the `TensorVariable.__abs__` operator so that
this function is called when you type abs(a).
""" # These are deprecated and will be removed
abs_ = abs
pprint.assign(abs_, printing.PatternPrinter(("|%(0)s|", -1000))) pprint.assign(abs, printing.PatternPrinter(("|%(0)s|", -1000)))
@scalar_elemwise @scalar_elemwise
...@@ -2786,6 +2790,7 @@ __all__ = [ ...@@ -2786,6 +2790,7 @@ __all__ = [
"bitwise_xor", "bitwise_xor",
"invert", "invert",
"bitwise_not", "bitwise_not",
"abs",
"abs_", "abs_",
"exp", "exp",
"exp2", "exp2",
......
...@@ -63,7 +63,9 @@ from aesara.tensor.math import ( ...@@ -63,7 +63,9 @@ from aesara.tensor.math import (
Prod, Prod,
ProdWithoutZeros, ProdWithoutZeros,
Sum, Sum,
abs_, )
from aesara.tensor.math import abs as aet_abs
from aesara.tensor.math import (
add, add,
dot, dot,
eq, eq,
...@@ -2177,7 +2179,7 @@ def check_for_x_over_absX(numerators, denominators): ...@@ -2177,7 +2179,7 @@ def check_for_x_over_absX(numerators, denominators):
# TODO: this function should dig/search through dimshuffles # TODO: this function should dig/search through dimshuffles
# This won't catch a dimshuffled absolute value # This won't catch a dimshuffled absolute value
for den in list(denominators): for den in list(denominators):
if den.owner and den.owner.op == abs_ and den.owner.inputs[0] in numerators: if den.owner and den.owner.op == aet_abs and den.owner.inputs[0] in numerators:
if den.owner.inputs[0].type.dtype.startswith("complex"): if den.owner.inputs[0].type.dtype.startswith("complex"):
# TODO: Make an Op that projects a complex number to # TODO: Make an Op that projects a complex number to
# have unit length but projects 0 to 0. That # have unit length but projects 0 to 0. That
...@@ -2197,7 +2199,7 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, "X_over_absX") ...@@ -2197,7 +2199,7 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, "X_over_absX")
@register_canonicalize @register_canonicalize
@local_optimizer([abs_]) @local_optimizer([aet_abs])
def local_abs_lift(fgraph, node): def local_abs_lift(fgraph, node):
""" """
Move the abs toward the input. Move the abs toward the input.
...@@ -2205,13 +2207,13 @@ def local_abs_lift(fgraph, node): ...@@ -2205,13 +2207,13 @@ def local_abs_lift(fgraph, node):
This is needed for check_for_x_over_absX to apply in more case. This is needed for check_for_x_over_absX to apply in more case.
""" """
if node.op == abs_ and node.inputs[0].owner: if node.op == aet_abs and node.inputs[0].owner:
assert node.nin == 1 assert node.nin == 1
if node.inputs[0].owner.op == mul: if node.inputs[0].owner.op == mul:
return [mul(*[abs_(i) for i in node.inputs[0].owner.inputs])] return [mul(*[aet_abs(i) for i in node.inputs[0].owner.inputs])]
if node.inputs[0].owner.op == true_div: if node.inputs[0].owner.op == true_div:
i = node.inputs[0].owner.inputs i = node.inputs[0].owner.inputs
return [true_div(abs_(i[0]), abs_(i[1]))] return [true_div(aet_abs(i[0]), aet_abs(i[1]))]
@register_specialize @register_specialize
...@@ -2222,10 +2224,13 @@ def local_abs_merge(fgraph, node): ...@@ -2222,10 +2224,13 @@ def local_abs_merge(fgraph, node):
need it anymore need it anymore
""" """
if node.op == mul and sum([i.owner.op == abs_ for i in node.inputs if i.owner]) > 1: if (
node.op == mul
and sum([i.owner.op == aet_abs for i in node.inputs if i.owner]) > 1
):
inputs = [] inputs = []
for i in node.inputs: for i in node.inputs:
if i.owner and i.owner.op == abs_: if i.owner and i.owner.op == aet_abs:
inputs.append(i.owner.inputs[0]) inputs.append(i.owner.inputs[0])
elif isinstance(i, Constant): elif isinstance(i, Constant):
try: try:
...@@ -2237,13 +2242,13 @@ def local_abs_merge(fgraph, node): ...@@ -2237,13 +2242,13 @@ def local_abs_merge(fgraph, node):
inputs.append(i) inputs.append(i)
else: else:
return False return False
return [abs_(mul(*inputs))] return [aet_abs(mul(*inputs))]
if ( if (
node.op == true_div node.op == true_div
and sum([i.owner.op == abs_ for i in node.inputs if i.owner]) == 2 and sum([i.owner.op == aet_abs for i in node.inputs if i.owner]) == 2
): ):
return [ return [
abs_( aet_abs(
true_div(node.inputs[0].owner.inputs[0], node.inputs[1].owner.inputs[0]) true_div(node.inputs[0].owner.inputs[0], node.inputs[1].owner.inputs[0])
) )
] ]
......
...@@ -16,7 +16,7 @@ from aesara.tensor.utils import hash_from_ndarray ...@@ -16,7 +16,7 @@ from aesara.tensor.utils import hash_from_ndarray
class _tensor_py_operators: class _tensor_py_operators:
def __abs__(self): def __abs__(self):
return aet.math.abs_(self) return aet.math.abs(self)
def __neg__(self): def __neg__(self):
return aet.math.neg(self) return aet.math.neg(self)
...@@ -695,7 +695,7 @@ class _tensor_py_operators: ...@@ -695,7 +695,7 @@ class _tensor_py_operators:
raise NotImplementedError() raise NotImplementedError()
# optimizations will/should catch cases like L=1, L=2 # optimizations will/should catch cases like L=1, L=2
y = aet.math.pow( y = aet.math.pow(
aet.math.pow(aet.math.abs_(self), L).sum(axis=axis), aet.math.pow(aet.math.abs(self), L).sum(axis=axis),
1.0 / L, 1.0 / L,
) )
if keepdims: if keepdims:
......
...@@ -1421,19 +1421,19 @@ Here is an example using the bit-wise ``and_`` via the ``&`` operator: ...@@ -1421,19 +1421,19 @@ Here is an example using the bit-wise ``and_`` via the ``&`` operator:
Mathematical Mathematical
------------ ------------
.. function:: abs_(a) .. function:: abs(a)
Returns a variable representing the absolute of a, ie ``|a|``. Returns a variable representing the absolute of ``a``, i.e. ``|a|``.
.. note:: Can also be accessed with ``abs(a)``. .. note:: Can also be accessed using `builtins.abs`: i.e. ``abs(a)``.
.. function:: angle(a) .. function:: angle(a)
Returns a variable representing angular component of complex-valued Tensor `a`. Returns a variable representing angular component of complex-valued Tensor ``a``.
.. function:: exp(a) .. function:: exp(a)
Returns a variable representing the exponential of a, ie e^a. Returns a variable representing the exponential of ``a``.
.. function:: maximum(a, b) .. function:: maximum(a, b)
...@@ -1445,7 +1445,7 @@ Mathematical ...@@ -1445,7 +1445,7 @@ Mathematical
.. function:: neg(a) .. function:: neg(a)
Returns a variable representing the negation of `a` (also ``-a``). Returns a variable representing the negation of ``a`` (also ``-a``).
.. function:: reciprocal(a) .. function:: reciprocal(a)
......
...@@ -5,7 +5,7 @@ from aesara import config ...@@ -5,7 +5,7 @@ from aesara import config
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.scalar.basic import round_half_away_from_zero_vec, upcast from aesara.scalar.basic import round_half_away_from_zero_vec, upcast
from aesara.tensor.inplace import ( from aesara.tensor.inplace import (
abs__inplace, abs_inplace,
add_inplace, add_inplace,
arccos_inplace, arccos_inplace,
arccosh_inplace, arccosh_inplace,
...@@ -184,7 +184,7 @@ TestSgnInplaceBroadcast = makeBroadcastTester( ...@@ -184,7 +184,7 @@ TestSgnInplaceBroadcast = makeBroadcastTester(
) )
TestAbsInplaceBroadcast = makeBroadcastTester( TestAbsInplaceBroadcast = makeBroadcastTester(
op=abs__inplace, op=abs_inplace,
expected=lambda x: np.abs(x), expected=lambda x: np.abs(x),
good=_good_broadcast_unary_normal_abs, good=_good_broadcast_unary_normal_abs,
inplace=True, inplace=True,
......
...@@ -39,7 +39,7 @@ from aesara.tensor.math import ( ...@@ -39,7 +39,7 @@ from aesara.tensor.math import (
ProdWithoutZeros, ProdWithoutZeros,
Sum, Sum,
_dot, _dot,
abs_, abs,
add, add,
allclose, allclose,
arccos, arccos,
...@@ -345,8 +345,8 @@ TestPowBroadcast = makeBroadcastTester( ...@@ -345,8 +345,8 @@ TestPowBroadcast = makeBroadcastTester(
) )
TestAbsBroadcast = makeBroadcastTester( TestAbsBroadcast = makeBroadcastTester(
op=abs_, op=abs,
expected=lambda x: abs(x), expected=lambda x: np.abs(x),
good=_good_broadcast_unary_normal, good=_good_broadcast_unary_normal,
grad=_grad_broadcast_unary_normal, grad=_grad_broadcast_unary_normal,
) )
...@@ -640,19 +640,19 @@ TestArctanhBroadcast = makeBroadcastTester( ...@@ -640,19 +640,19 @@ TestArctanhBroadcast = makeBroadcastTester(
# Complex operations # Complex operations
_good_complex_from_polar = dict( _good_complex_from_polar = dict(
same_shapes=(abs(rand(2, 3)), rand(2, 3)), same_shapes=(np.abs(rand(2, 3)), rand(2, 3)),
not_same_dimensions=(abs(rand(2, 2)), rand(2)), not_same_dimensions=(np.abs(rand(2, 2)), rand(2)),
scalar=(abs(rand(2, 3)), rand(1, 1)), scalar=(np.abs(rand(2, 3)), rand(1, 1)),
row=(abs(rand(2, 3)), rand(1, 3)), row=(np.abs(rand(2, 3)), rand(1, 3)),
column=(abs(rand(2, 3)), rand(2, 1)), column=(np.abs(rand(2, 3)), rand(2, 1)),
integers=(abs(randint(2, 3)), randint(2, 3)), integers=(np.abs(randint(2, 3)), randint(2, 3)),
empty=(np.asarray([], dtype=config.floatX), np.asarray([1], dtype=config.floatX)), empty=(np.asarray([], dtype=config.floatX), np.asarray([1], dtype=config.floatX)),
) )
_grad_complex_from_polar = dict( _grad_complex_from_polar = dict(
same_shapes=(abs(rand(2, 3)), rand(2, 3)), same_shapes=(np.abs(rand(2, 3)), rand(2, 3)),
scalar=(abs(rand(2, 3)), rand(1, 1)), scalar=(np.abs(rand(2, 3)), rand(1, 1)),
row=(abs(rand(2, 3)), rand(1, 3)), row=(np.abs(rand(2, 3)), rand(1, 3)),
column=(abs(rand(2, 3)), rand(2, 1)), column=(np.abs(rand(2, 3)), rand(2, 1)),
) )
TestComplexFromPolarBroadcast = makeBroadcastTester( TestComplexFromPolarBroadcast = makeBroadcastTester(
......
...@@ -28,7 +28,9 @@ from aesara.tensor.basic_opt import local_dimshuffle_lift ...@@ -28,7 +28,9 @@ from aesara.tensor.basic_opt import local_dimshuffle_lift
from aesara.tensor.blas import Dot22, Gemv from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum, abs_, add from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum
from aesara.tensor.math import abs as aet_abs
from aesara.tensor.math import add
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
from aesara.tensor.math import any as aet_any from aesara.tensor.math import any as aet_any
from aesara.tensor.math import ( from aesara.tensor.math import (
...@@ -848,7 +850,7 @@ class TestAlgebraicCanonize: ...@@ -848,7 +850,7 @@ class TestAlgebraicCanonize:
# 4 * x / abs(2*x) it get simplifier during canonicalisation. # 4 * x / abs(2*x) it get simplifier during canonicalisation.
x = dscalar() x = dscalar()
# a = aet.abs_(x) # a = aet.aet_abs(x)
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
mode = get_mode("FAST_RUN").excluding("local_elemwise_fusion") mode = get_mode("FAST_RUN").excluding("local_elemwise_fusion")
...@@ -981,7 +983,7 @@ def test_merge_abs_bugfix(): ...@@ -981,7 +983,7 @@ def test_merge_abs_bugfix():
# normalize on rows # normalize on rows
step2 = step1 / step1.sum(1) step2 = step1 / step1.sum(1)
# get l1 norm # get l1 norm
l1_norm = abs_(step2).sum() l1_norm = aet_abs(step2).sum()
function([input], aesara.gradient.grad(l1_norm, input)) function([input], aesara.gradient.grad(l1_norm, input))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论