提交 bfabe261 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename inv to reciprocal

上级 9ac6149c
...@@ -1543,7 +1543,7 @@ class ProfileStats: ...@@ -1543,7 +1543,7 @@ class ProfileStats:
aes.Cast, aes.Cast,
aes.Sgn, aes.Sgn,
aes.Neg, aes.Neg,
aes.Inv, aes.Reciprocal,
aes.Sqr, aes.Sqr,
] ]
scalar_op_amdlibm_speed_up = [ scalar_op_amdlibm_speed_up = [
......
...@@ -3385,7 +3385,7 @@ def dnn_batch_normalization_train( ...@@ -3385,7 +3385,7 @@ def dnn_batch_normalization_train(
axes = 0 if mode == 'per-activation' else (0, 2, 3) axes = 0 if mode == 'per-activation' else (0, 2, 3)
mean = inputs.mean(axes, keepdims=True) mean = inputs.mean(axes, keepdims=True)
var = inputs.var(axes, keepdims=True) var = inputs.var(axes, keepdims=True)
invstd = aet.inv(aet.sqrt(var + epsilon)) invstd = aet.reciprocal(aet.sqrt(var + epsilon))
out = (inputs - mean) * gamma * invstd + beta out = (inputs - mean) * gamma * invstd + beta
m = aet.cast(aet.prod(inputs.shape) / aet.prod(mean.shape), 'float32') m = aet.cast(aet.prod(inputs.shape) / aet.prod(mean.shape), 'float32')
......
...@@ -2846,11 +2846,8 @@ pprint.assign(pow, printing.OperatorPrinter("**", 1, "right")) ...@@ -2846,11 +2846,8 @@ pprint.assign(pow, printing.OperatorPrinter("**", 1, "right"))
pprint.assign(mod, printing.OperatorPrinter("%", -1, "left")) pprint.assign(mod, printing.OperatorPrinter("%", -1, "left"))
class Inv(UnaryScalarOp): class Reciprocal(UnaryScalarOp):
""" """Multiplicative inverse."""
Multiplicative inverse. Also called reciprocal.
"""
nfunc_spec = ("reciprocal", 1, 1) nfunc_spec = ("reciprocal", 1, 1)
...@@ -2878,7 +2875,11 @@ class Inv(UnaryScalarOp): ...@@ -2878,7 +2875,11 @@ class Inv(UnaryScalarOp):
return f"{z} = 1.0 / {x};" return f"{z} = 1.0 / {x};"
inv = Inv(upgrade_to_float, name="inv") reciprocal = Reciprocal(upgrade_to_float, name="reciprocal")
# These are deprecated and will be removed
Inv = Reciprocal
inv = reciprocal
class Log(UnaryScalarOp): class Log(UnaryScalarOp):
......
...@@ -79,7 +79,7 @@ def neg_inplace(a): ...@@ -79,7 +79,7 @@ def neg_inplace(a):
@scalar_elemwise @scalar_elemwise
def inv_inplace(a): def reciprocal_inplace(a):
"""1.0/a (inplace on a)""" """1.0/a (inplace on a)"""
......
...@@ -1065,13 +1065,15 @@ def neg(a): ...@@ -1065,13 +1065,15 @@ def neg(a):
"""-a""" """-a"""
# numpy.reciprocal does integer division on integer inputs
# (which is not very interesting)
@scalar_elemwise @scalar_elemwise
def inv(a): def reciprocal(a):
"""1.0/a""" """1.0/a"""
# This is deprecated and will be removed
inv = reciprocal
@scalar_elemwise @scalar_elemwise
def log(a): def log(a):
"""base e logarithm of a""" """base e logarithm of a"""
...@@ -2789,6 +2791,7 @@ __all__ = [ ...@@ -2789,6 +2791,7 @@ __all__ = [
"exp2", "exp2",
"expm1", "expm1",
"neg", "neg",
"reciprocal",
"inv", "inv",
"log", "log",
"log2", "log2",
......
...@@ -74,7 +74,6 @@ from aesara.tensor.math import ( ...@@ -74,7 +74,6 @@ from aesara.tensor.math import (
expm1, expm1,
ge, ge,
int_div, int_div,
inv,
log, log,
log1p, log1p,
makeKeepDims, makeKeepDims,
...@@ -82,7 +81,7 @@ from aesara.tensor.math import ( ...@@ -82,7 +81,7 @@ from aesara.tensor.math import (
from aesara.tensor.math import max as aet_max from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, mul, neg from aesara.tensor.math import maximum, mul, neg
from aesara.tensor.math import pow as aet_pow from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import prod, sgn, sigmoid, softplus, sqr, sqrt, sub from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import true_div from aesara.tensor.math import true_div
from aesara.tensor.shape import Shape, Shape_i from aesara.tensor.shape import Shape, Shape_i
...@@ -206,7 +205,7 @@ def local_func_inv(fgraph, node): ...@@ -206,7 +205,7 @@ def local_func_inv(fgraph, node):
(aes.Sinh, aes.ArcSinh), (aes.Sinh, aes.ArcSinh),
(aes.Conj, aes.Conj), (aes.Conj, aes.Conj),
(aes.Neg, aes.Neg), (aes.Neg, aes.Neg),
(aes.Inv, aes.Inv), (aes.Reciprocal, aes.Reciprocal),
) )
x = node.inputs[0] x = node.inputs[0]
...@@ -511,28 +510,29 @@ def local_div_switch_sink(fgraph, node): ...@@ -511,28 +510,29 @@ def local_div_switch_sink(fgraph, node):
class AlgebraicCanonizer(LocalOptimizer): class AlgebraicCanonizer(LocalOptimizer):
r""" r"""Simplification tool.
Simplification tool. The variable is a local_optimizer. It is best used
with a TopoOptimizer in in_to_out order.
Usage: AlgebraicCanonizer(main, inverse, reciprocal, calculate) The variable is a ``local_optimizer``. It is best used
with a ``TopoOptimizer`` in ``in_to_out`` order.
Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)``
Parameters Parameters
---------- ----------
main main
A suitable Op class that is commutative, associative and A suitable ``Op`` class that is commutative, associative and
takes one to an arbitrary number of inputs, e.g. add or takes one to an arbitrary number of inputs, e.g. add or
mul mul
inverse inverse
An Op class such that inverse(main(x, y), y) == x An ``Op`` class such that ``inverse(main(x, y), y) == x``
e.g. sub or true_div e.g. ``sub`` or true_div
reciprocal reciprocal
A function such that main(x, reciprocal(y)) == inverse(x, y) A function such that ``main(x, reciprocal(y)) == inverse(x, y)``
e.g. neg or inv e.g. ``neg`` or ``reciprocal``
calculate calculate
Function that takes a list of numpy.ndarray instances Function that takes a list of numpy.ndarray instances
for the numerator, another list for the denumerator, for the numerator, another list for the denumerator,
and calculates inverse(main(\*num), main(\*denum)). It and calculates ``inverse(main(\*num), main(\*denum))``. It
takes a keyword argument, aslist. If True, the value takes a keyword argument, aslist. If True, the value
should be returned as a list of one element, unless should be returned as a list of one element, unless
the value is such that value = main(). In that case, the value is such that value = main(). In that case,
...@@ -547,7 +547,7 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -547,7 +547,7 @@ class AlgebraicCanonizer(LocalOptimizer):
>>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\ >>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\
... lambda n, d: prod(n) / prod(d)) ... lambda n, d: prod(n) / prod(d))
Examples of optimizations mul_canonizer can perform: Examples of optimizations ``mul_canonizer`` can perform:
| x / x -> 1 | x / x -> 1
| (x * y) / x -> y | (x * y) / x -> y
...@@ -562,10 +562,10 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -562,10 +562,10 @@ class AlgebraicCanonizer(LocalOptimizer):
""" """
def __init__(self, main, inverse, reciprocal, calculate, use_reciprocal=True): def __init__(self, main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True):
self.main = main self.main = main
self.inverse = inverse self.inverse = inverse_fn
self.reciprocal = reciprocal self.reciprocal = reciprocal_fn
self.calculate = calculate self.calculate = calculate
self.use_reciprocal = use_reciprocal self.use_reciprocal = use_reciprocal
...@@ -579,11 +579,11 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -579,11 +579,11 @@ class AlgebraicCanonizer(LocalOptimizer):
def get_num_denum(self, input): def get_num_denum(self, input):
r""" r"""
This extract two lists, num and denum, such that the input is: This extract two lists, ``num`` and ``denum``, such that the input is:
self.inverse(self.main(\*num), self.main(\*denum)). It returns ``self.inverse(self.main(\*num), self.main(\*denum))``. It returns
the two lists in a (num, denum) pair. the two lists in a ``(num, denum)`` pair.
For example, for main, inverse and reciprocal = \*, / and inv(), For example, for main, inverse and ``reciprocal = \*, / and inv()``,
| input -> returned value (num, denum) | input -> returned value (num, denum)
...@@ -1013,7 +1013,9 @@ def mul_calculate(num, denum, aslist=False, out_type=None): ...@@ -1013,7 +1013,9 @@ def mul_calculate(num, denum, aslist=False, out_type=None):
return v return v
local_mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, mul_calculate, False) local_mul_canonizer = AlgebraicCanonizer(
mul, true_div, reciprocal, mul_calculate, False
)
register_canonicalize(local_mul_canonizer, name="local_mul_canonizer") register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
...@@ -1847,12 +1849,12 @@ register_canonicalize(local_mul_zero) ...@@ -1847,12 +1849,12 @@ register_canonicalize(local_mul_zero)
@local_optimizer([true_div]) @local_optimizer([true_div])
def local_div_to_inv(fgraph, node): def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all( if node.op == true_div and np.all(
local_mul_canonizer.get_constant(node.inputs[0]) == 1.0 local_mul_canonizer.get_constant(node.inputs[0]) == 1.0
): ):
out = node.outputs[0] out = node.outputs[0]
new_out = inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting # The ones could have forced upcasting
if new_out.dtype != out.dtype: if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype) new_out = cast(new_out, dtype=out.dtype)
...@@ -1864,18 +1866,19 @@ def local_div_to_inv(fgraph, node): ...@@ -1864,18 +1866,19 @@ def local_div_to_inv(fgraph, node):
return False return False
register_specialize(local_div_to_inv) # TODO: Add this to the canonicalization to reduce redundancy.
register_specialize(local_div_to_reciprocal)
@local_optimizer([inv]) @local_optimizer([reciprocal])
def local_inv_canon(fgraph, node): def local_reciprocal_canon(fgraph, node):
if node.op == inv: if node.op == reciprocal:
return [aet_pow(node.inputs[0], -1.0)] return [aet_pow(node.inputs[0], -1.0)]
else: else:
return False return False
register_canonicalize(local_inv_canon) register_canonicalize(local_reciprocal_canon)
@local_optimizer([aet_pow]) @local_optimizer([aet_pow])
...@@ -1958,11 +1961,11 @@ def local_pow_specialize(fgraph, node): ...@@ -1958,11 +1961,11 @@ def local_pow_specialize(fgraph, node):
if np.all(y == 0.5): if np.all(y == 0.5):
rval = [sqrt(xsym)] rval = [sqrt(xsym)]
if np.all(y == -0.5): if np.all(y == -0.5):
rval = [inv(sqrt(xsym))] rval = [reciprocal(sqrt(xsym))]
if np.all(y == -1): if np.all(y == -1):
rval = [inv(xsym)] rval = [reciprocal(xsym)]
if np.all(y == -2): if np.all(y == -2):
rval = [inv(sqr(xsym))] rval = [reciprocal(sqr(xsym))]
if rval: if rval:
rval[0] = cast(rval[0], odtype) rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs) assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
...@@ -2032,7 +2035,7 @@ def local_pow_specialize_device(fgraph, node): ...@@ -2032,7 +2035,7 @@ def local_pow_specialize_device(fgraph, node):
aes.Composite([pow2_scal[0]], [rval1_scal]) aes.Composite([pow2_scal[0]], [rval1_scal])
).make_node(xsym) ).make_node(xsym)
if y < 0: if y < 0:
rval = [inv(rval1)] rval = [reciprocal(rval1)]
else: else:
rval = [rval1] rval = [rval1]
if rval: if rval:
...@@ -2476,7 +2479,7 @@ def attempt_distribution(factor, num, denum, out_type): ...@@ -2476,7 +2479,7 @@ def attempt_distribution(factor, num, denum, out_type):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@local_optimizer([mul, true_div, inv]) @local_optimizer([mul, true_div, reciprocal])
def local_greedy_distributor(fgraph, node): def local_greedy_distributor(fgraph, node):
""" """
Optimize by reducing the number of multiplications and/or divisions. Optimize by reducing the number of multiplications and/or divisions.
...@@ -3584,19 +3587,21 @@ def local_sigm_times_exp(fgraph, node): ...@@ -3584,19 +3587,21 @@ def local_sigm_times_exp(fgraph, node):
@register_stabilize @register_stabilize
@local_optimizer([inv]) @local_optimizer([reciprocal])
def local_inv_1_plus_exp(fgraph, node): def local_reciprocal_1_plus_exp(fgraph, node):
""" """``reciprocal(1+exp(x)) -> sigm(-x)``
1/(1+exp(x)) -> sigm(-x)
TODO: This is redundant; we can just decided on *one* canonical form
for division (e.g. either `true_div` or `reciprocal`) and have this
taken care of with existing rewrites.
""" """
# this optimization should be done for numerical stability # this optimization should be done for numerical stability
# so we don't care to check client counts # so we don't care to check client counts
if node.op == inv: if node.op == reciprocal:
inv_arg = node.inputs[0] reciprocal_arg = node.inputs[0]
if inv_arg.owner and inv_arg.owner.op == add: if reciprocal_arg.owner and reciprocal_arg.owner.op == add:
scalars_, scalar_inputs, nonconsts = scalarconsts_rest( scalars_, scalar_inputs, nonconsts = scalarconsts_rest(
inv_arg.owner.inputs, only_process_constants=True reciprocal_arg.owner.inputs, only_process_constants=True
) )
# scalar_inputs are potentially dimshuffled and fill'd scalars # scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1: if len(nonconsts) == 1:
...@@ -3608,9 +3613,11 @@ def local_inv_1_plus_exp(fgraph, node): ...@@ -3608,9 +3613,11 @@ def local_inv_1_plus_exp(fgraph, node):
) )
# keep combined stack traces of # keep combined stack traces of
# exp(x): nonconsts[0], # exp(x): nonconsts[0],
# 1 + exp(x): inv_arg, # 1 + exp(x): reciprocal_arg,
# 1 / (1 + exp(x)): node.outputs[0] # 1 / (1 + exp(x)): node.outputs[0]
copy_stack_trace([nonconsts[0], inv_arg, node.outputs[0]], out) copy_stack_trace(
[nonconsts[0], reciprocal_arg, node.outputs[0]], out
)
return out return out
......
...@@ -10,7 +10,7 @@ from aesara.tensor import basic as aet ...@@ -10,7 +10,7 @@ from aesara.tensor import basic as aet
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.basic_opt import register_specialize_device from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import inv, mean, prod, sqrt from aesara.tensor.math import mean, prod, reciprocal, sqrt
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
...@@ -185,7 +185,7 @@ def batch_normalization_train( ...@@ -185,7 +185,7 @@ def batch_normalization_train(
axes = (0,) + tuple(range(2, inputs.ndim)) axes = (0,) + tuple(range(2, inputs.ndim))
mean = inputs.mean(axes, keepdims=True) mean = inputs.mean(axes, keepdims=True)
var = inputs.var(axes, keepdims=True) var = inputs.var(axes, keepdims=True)
invstd = aet.inv(aet.sqrt(var + epsilon)) invstd = aet.reciprocal(aet.sqrt(var + epsilon))
out = (inputs - mean) * gamma * invstd + beta out = (inputs - mean) * gamma * invstd + beta
m = aet.cast(aet.prod(inputs.shape) / aet.prod(mean.shape), 'float32') m = aet.cast(aet.prod(inputs.shape) / aet.prod(mean.shape), 'float32')
...@@ -802,7 +802,7 @@ def local_abstract_batch_norm_train(fgraph, node): ...@@ -802,7 +802,7 @@ def local_abstract_batch_norm_train(fgraph, node):
# The epsilon should not upcast the dtype. # The epsilon should not upcast the dtype.
if var.dtype == "float32" and epsilon.dtype == "float64": if var.dtype == "float32" and epsilon.dtype == "float64":
epsilon = epsilon.astype("float32") epsilon = epsilon.astype("float32")
invstd = inv(sqrt(var + epsilon)) invstd = reciprocal(sqrt(var + epsilon))
out = (x - mean) * (scale * invstd) + bias out = (x - mean) * (scale * invstd) + bias
results = [out, mean, invstd] results = [out, mean, invstd]
......
...@@ -1435,7 +1435,7 @@ Mathematical ...@@ -1435,7 +1435,7 @@ Mathematical
Returns a variable representing the negation of `a` (also ``-a``). Returns a variable representing the negation of `a` (also ``-a``).
.. function:: inv(a) .. function:: reciprocal(a)
Returns a variable representing the inverse of a, ie 1.0/a. Also called reciprocal. Returns a variable representing the inverse of a, ie 1.0/a. Also called reciprocal.
......
...@@ -205,7 +205,7 @@ Optimization o4 o3 o2 ...@@ -205,7 +205,7 @@ Optimization o4 o3 o2
recognize them. Some examples include: recognize them. Some examples include:
* ``pow(x,2)`` -> ``x**2`` * ``pow(x,2)`` -> ``x**2``
* ``pow(x,0)`` -> ``ones_like(x)`` * ``pow(x,0)`` -> ``ones_like(x)``
* ``pow(x, -0.5)`` -> ``inv(sqrt(x))`` * ``pow(x, -0.5)`` -> ``reciprocal(sqrt(x))``
See :func:`local_pow_specialize` See :func:`local_pow_specialize`
......
...@@ -22,13 +22,13 @@ from aesara.tensor.math import ( ...@@ -22,13 +22,13 @@ from aesara.tensor.math import (
clip, clip,
dot, dot,
floor, floor,
inv,
log, log,
max_and_argmax, max_and_argmax,
mean, mean,
minimum, minimum,
mod, mod,
prod, prod,
reciprocal,
sqrt, sqrt,
) )
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
...@@ -1879,7 +1879,7 @@ def test_dnn_batchnorm_train(): ...@@ -1879,7 +1879,7 @@ def test_dnn_batchnorm_train():
axes = (0,) + tuple(range(2, ndim)) axes = (0,) + tuple(range(2, ndim))
x_mean_ref = x.mean(axis=axes, keepdims=True) x_mean_ref = x.mean(axis=axes, keepdims=True)
x_var_ref = x.var(axis=axes, keepdims=True) x_var_ref = x.var(axis=axes, keepdims=True)
x_invstd_ref = inv(sqrt(x_var_ref + eps)) x_invstd_ref = reciprocal(sqrt(x_var_ref + eps))
scale_ref = aet.addbroadcast(scale, *axes) scale_ref = aet.addbroadcast(scale, *axes)
bias_ref = aet.addbroadcast(bias, *axes) bias_ref = aet.addbroadcast(bias, *axes)
m = aet.cast(prod(x.shape) / prod(scale.shape), aesara.config.floatX) m = aet.cast(prod(x.shape) / prod(scale.shape), aesara.config.floatX)
......
...@@ -46,7 +46,6 @@ from aesara.scalar.basic import ( ...@@ -46,7 +46,6 @@ from aesara.scalar.basic import (
int8, int8,
int32, int32,
ints, ints,
inv,
invert, invert,
log, log,
log1p, log1p,
...@@ -55,6 +54,7 @@ from aesara.scalar.basic import ( ...@@ -55,6 +54,7 @@ from aesara.scalar.basic import (
mul, mul,
neq, neq,
rad2deg, rad2deg,
reciprocal,
sin, sin,
sinh, sinh,
sqrt, sqrt,
...@@ -281,7 +281,7 @@ class TestUpgradeToFloat: ...@@ -281,7 +281,7 @@ class TestUpgradeToFloat:
# at least float32, not float16. # at least float32, not float16.
unary_ops_vals = [ unary_ops_vals = [
(inv, list(range(-127, 0)) + list(range(1, 127))), (reciprocal, list(range(-127, 0)) + list(range(1, 127))),
(sqrt, list(range(0, 128))), (sqrt, list(range(0, 128))),
(log, list(range(1, 128))), (log, list(range(1, 128))),
(log2, list(range(1, 128))), (log2, list(range(1, 128))),
......
...@@ -219,7 +219,7 @@ def test_batch_normalization_train(): ...@@ -219,7 +219,7 @@ def test_batch_normalization_train():
axes2 = axes axes2 = axes
x_mean2 = x.mean(axis=axes2, keepdims=True) x_mean2 = x.mean(axis=axes2, keepdims=True)
x_var2 = x.var(axis=axes2, keepdims=True) x_var2 = x.var(axis=axes2, keepdims=True)
x_invstd2 = aet.inv(aet.sqrt(x_var2 + eps)) x_invstd2 = aet.reciprocal(aet.sqrt(x_var2 + eps))
scale2 = aet.addbroadcast(scale, *axes2) scale2 = aet.addbroadcast(scale, *axes2)
bias2 = aet.addbroadcast(bias, *axes2) bias2 = aet.addbroadcast(bias, *axes2)
out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2 out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2
......
...@@ -66,7 +66,6 @@ from aesara.tensor.math import ( ...@@ -66,7 +66,6 @@ from aesara.tensor.math import (
ge, ge,
gt, gt,
int_div, int_div,
inv,
invert, invert,
iround, iround,
le, le,
...@@ -81,6 +80,7 @@ from aesara.tensor.math import ( ...@@ -81,6 +80,7 @@ from aesara.tensor.math import (
neq, neq,
) )
from aesara.tensor.math import pow as aet_pow from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import reciprocal
from aesara.tensor.math import round as aet_round from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
...@@ -845,7 +845,7 @@ class TestFusion: ...@@ -845,7 +845,7 @@ class TestFusion:
"float32", "float32",
), ),
( (
fx - fy + inv(fz), fx - fy + reciprocal(fz),
(fx, fy, fz), (fx, fy, fz),
(fxv, fyv, fzv), (fxv, fyv, fzv),
1, 1,
......
...@@ -24,7 +24,6 @@ from aesara.tensor.inplace import ( ...@@ -24,7 +24,6 @@ from aesara.tensor.inplace import (
expm1_inplace, expm1_inplace,
floor_inplace, floor_inplace,
int_div_inplace, int_div_inplace,
inv_inplace,
log1p_inplace, log1p_inplace,
log2_inplace, log2_inplace,
log10_inplace, log10_inplace,
...@@ -36,6 +35,7 @@ from aesara.tensor.inplace import ( ...@@ -36,6 +35,7 @@ from aesara.tensor.inplace import (
neg_inplace, neg_inplace,
pow_inplace, pow_inplace,
rad2deg_inplace, rad2deg_inplace,
reciprocal_inplace,
round_half_away_from_zero_inplace, round_half_away_from_zero_inplace,
round_half_to_even_inplace, round_half_to_even_inplace,
sgn_inplace, sgn_inplace,
...@@ -55,7 +55,7 @@ from tests import unittest_tools as utt ...@@ -55,7 +55,7 @@ from tests import unittest_tools as utt
from tests.tensor.utils import ( from tests.tensor.utils import (
_bad_build_broadcast_binary_normal, _bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal, _bad_runtime_broadcast_binary_normal,
_bad_runtime_inv, _bad_runtime_reciprocal,
_good_broadcast_binary_arctan2, _good_broadcast_binary_arctan2,
_good_broadcast_binary_normal, _good_broadcast_binary_normal,
_good_broadcast_div_mod_normal_float_inplace, _good_broadcast_div_mod_normal_float_inplace,
...@@ -72,7 +72,7 @@ from tests.tensor.utils import ( ...@@ -72,7 +72,7 @@ from tests.tensor.utils import (
_good_broadcast_unary_positive_float, _good_broadcast_unary_positive_float,
_good_broadcast_unary_tan, _good_broadcast_unary_tan,
_good_broadcast_unary_wide_float, _good_broadcast_unary_wide_float,
_good_inv_inplace, _good_reciprocal_inplace,
_numpy_true_div, _numpy_true_div,
angle_eps, angle_eps,
check_floatX, check_floatX,
...@@ -142,11 +142,11 @@ TestTrueDivInplaceBroadcast = makeBroadcastTester( ...@@ -142,11 +142,11 @@ TestTrueDivInplaceBroadcast = makeBroadcastTester(
inplace=True, inplace=True,
) )
TestInvInplaceBroadcast = makeBroadcastTester( TestReciprocalInplaceBroadcast = makeBroadcastTester(
op=inv_inplace, op=reciprocal_inplace,
expected=lambda x: _numpy_true_div(np.int8(1), x), expected=lambda x: _numpy_true_div(np.int8(1), x),
good=_good_inv_inplace, good=_good_reciprocal_inplace,
bad_runtime=_bad_runtime_inv, bad_runtime=_bad_runtime_reciprocal,
grad_rtol=div_grad_rtol, grad_rtol=div_grad_rtol,
inplace=True, inplace=True,
) )
......
...@@ -67,7 +67,6 @@ from aesara.tensor.math import ( ...@@ -67,7 +67,6 @@ from aesara.tensor.math import (
exp2, exp2,
expm1, expm1,
floor, floor,
inv,
isclose, isclose,
isinf, isinf,
isnan, isnan,
...@@ -90,6 +89,7 @@ from aesara.tensor.math import ( ...@@ -90,6 +89,7 @@ from aesara.tensor.math import (
power, power,
ptp, ptp,
rad2deg, rad2deg,
reciprocal,
round_half_away_from_zero, round_half_away_from_zero,
round_half_to_even, round_half_to_even,
sgn, sgn,
...@@ -136,7 +136,7 @@ from tests import unittest_tools as utt ...@@ -136,7 +136,7 @@ from tests import unittest_tools as utt
from tests.tensor.utils import ( from tests.tensor.utils import (
_bad_build_broadcast_binary_normal, _bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal, _bad_runtime_broadcast_binary_normal,
_bad_runtime_inv, _bad_runtime_reciprocal,
_eps, _eps,
_good_broadcast_binary_arctan2, _good_broadcast_binary_arctan2,
_good_broadcast_binary_normal, _good_broadcast_binary_normal,
...@@ -153,14 +153,14 @@ from tests.tensor.utils import ( ...@@ -153,14 +153,14 @@ from tests.tensor.utils import (
_good_broadcast_unary_positive, _good_broadcast_unary_positive,
_good_broadcast_unary_tan, _good_broadcast_unary_tan,
_good_broadcast_unary_wide, _good_broadcast_unary_wide,
_good_inv, _good_reciprocal,
_grad_broadcast_binary_normal, _grad_broadcast_binary_normal,
_grad_broadcast_pow_normal, _grad_broadcast_pow_normal,
_grad_broadcast_unary_normal, _grad_broadcast_unary_normal,
_grad_broadcast_unary_normal_no_complex, _grad_broadcast_unary_normal_no_complex,
_grad_broadcast_unary_normal_no_complex_no_corner_case, _grad_broadcast_unary_normal_no_complex_no_corner_case,
_grad_broadcast_unary_normal_noint, _grad_broadcast_unary_normal_noint,
_grad_inv, _grad_reciprocal,
_numpy_true_div, _numpy_true_div,
angle_eps, angle_eps,
check_floatX, check_floatX,
...@@ -308,11 +308,11 @@ TestTrueDivBroadcast = makeBroadcastTester( ...@@ -308,11 +308,11 @@ TestTrueDivBroadcast = makeBroadcastTester(
) )
TestInvBroadcast = makeBroadcastTester( TestInvBroadcast = makeBroadcastTester(
op=inv, op=reciprocal,
expected=lambda x: upcast_int8_nfunc(np.true_divide)(np.int8(1), x), expected=lambda x: upcast_int8_nfunc(np.true_divide)(np.int8(1), x),
good=_good_inv, good=_good_reciprocal,
bad_runtime=_bad_runtime_inv, bad_runtime=_bad_runtime_reciprocal,
grad=_grad_inv, grad=_grad_reciprocal,
grad_rtol=div_grad_rtol, grad_rtol=div_grad_rtol,
) )
......
...@@ -52,7 +52,6 @@ from aesara.tensor.math import ( ...@@ -52,7 +52,6 @@ from aesara.tensor.math import (
ge, ge,
gt, gt,
int_div, int_div,
inv,
invert, invert,
iround, iround,
le, le,
...@@ -67,7 +66,7 @@ from aesara.tensor.math import maximum ...@@ -67,7 +66,7 @@ from aesara.tensor.math import maximum
from aesara.tensor.math import min as aet_min from aesara.tensor.math import min as aet_min
from aesara.tensor.math import minimum, mul, neg, neq from aesara.tensor.math import minimum, mul, neg, neq
from aesara.tensor.math import pow as aet_pow from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import prod, rad2deg from aesara.tensor.math import prod, rad2deg, reciprocal
from aesara.tensor.math import round as aet_round from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sgn, sigmoid, sin, sinh, sqr, sqrt, sub from aesara.tensor.math import sgn, sigmoid, sin, sinh, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
...@@ -595,9 +594,9 @@ class TestAlgebraicCanonize: ...@@ -595,9 +594,9 @@ class TestAlgebraicCanonize:
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"), ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as there is a dimshuffle in the computation # must broadcast as there is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"), ((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc] # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"), ((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc] # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
] ]
): ):
f = function(list(sym_inputs), g, mode=mode) f = function(list(sym_inputs), g, mode=mode)
...@@ -609,7 +608,7 @@ class TestAlgebraicCanonize: ...@@ -609,7 +608,7 @@ class TestAlgebraicCanonize:
assert isinstance(elem[0].op, (Elemwise,)) assert isinstance(elem[0].op, (Elemwise,))
assert isinstance( assert isinstance(
elem[0].op.scalar_op, elem[0].op.scalar_op,
(aes.basic.Inv, aes.basic.TrueDiv), (aes.basic.Reciprocal, aes.basic.TrueDiv),
) )
assert out_dtype == out.dtype assert out_dtype == out.dtype
...@@ -912,7 +911,7 @@ class TestAlgebraicCanonize: ...@@ -912,7 +911,7 @@ class TestAlgebraicCanonize:
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2 assert len(topo) == 2
assert isinstance(topo[0].op, (Elemwise,)) assert isinstance(topo[0].op, (Elemwise,))
assert isinstance(topo[0].op.scalar_op, aes.basic.Inv) assert isinstance(topo[0].op.scalar_op, aes.basic.Reciprocal)
assert len(topo[0].inputs) == 1 assert len(topo[0].inputs) == 1
assert out_dtype == out.dtype assert out_dtype == out.dtype
...@@ -927,7 +926,7 @@ class TestAlgebraicCanonize: ...@@ -927,7 +926,7 @@ class TestAlgebraicCanonize:
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2 assert len(topo) == 2
assert isinstance(topo[0].op, (Elemwise,)) assert isinstance(topo[0].op, (Elemwise,))
assert isinstance(topo[0].op.scalar_op, aes.basic.Inv) assert isinstance(topo[0].op.scalar_op, aes.basic.Reciprocal)
assert len(topo[0].inputs) == 1 assert len(topo[0].inputs) == 1
assert out_dtype == out.dtype assert out_dtype == out.dtype
...@@ -1545,7 +1544,7 @@ class TestFusion: ...@@ -1545,7 +1544,7 @@ class TestFusion:
"float32", "float32",
), ),
( (
fx - fy + inv(fz), fx - fy + reciprocal(fz),
(fx, fy, fz), (fx, fy, fz),
(fxv, fyv, fzv), (fxv, fyv, fzv),
1, 1,
...@@ -2360,7 +2359,7 @@ def test_local_pow_specialize(): ...@@ -2360,7 +2359,7 @@ def test_local_pow_specialize():
f = function([v], v ** (-1), mode=mode) f = function([v], v ** (-1), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()] nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [inv] assert nodes == [reciprocal]
utt.assert_allclose(f(val_no0), val_no0 ** (-1)) utt.assert_allclose(f(val_no0), val_no0 ** (-1))
f = function([v], v ** 2, mode=mode) f = function([v], v ** 2, mode=mode)
...@@ -2372,7 +2371,7 @@ def test_local_pow_specialize(): ...@@ -2372,7 +2371,7 @@ def test_local_pow_specialize():
nodes = [node.op for node in f.maker.fgraph.toposort()] nodes = [node.op for node in f.maker.fgraph.toposort()]
assert len(nodes) == 2 assert len(nodes) == 2
assert nodes[0] == sqr assert nodes[0] == sqr
assert isinstance(nodes[1].scalar_op, aes.basic.Inv) assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-2)) utt.assert_allclose(f(val_no0), val_no0 ** (-2))
f = function([v], v ** (0.5), mode=mode) f = function([v], v ** (0.5), mode=mode)
...@@ -2384,7 +2383,7 @@ def test_local_pow_specialize(): ...@@ -2384,7 +2383,7 @@ def test_local_pow_specialize():
nodes = [node.op for node in f.maker.fgraph.toposort()] nodes = [node.op for node in f.maker.fgraph.toposort()]
assert len(nodes) == 2 assert len(nodes) == 2
assert nodes[0] == sqrt assert nodes[0] == sqrt
assert isinstance(nodes[1].scalar_op, aes.basic.Inv) assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5)) utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
...@@ -2410,7 +2409,7 @@ def test_local_pow_specialize_device_more_aggressive_on_cpu(): ...@@ -2410,7 +2409,7 @@ def test_local_pow_specialize_device_more_aggressive_on_cpu():
assert len(nodes) == 2 assert len(nodes) == 2
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 6 assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 6
assert isinstance(nodes[0].scalar_op, aes.Composite) assert isinstance(nodes[0].scalar_op, aes.Composite)
assert isinstance(nodes[-1].scalar_op, aes.basic.Inv) assert isinstance(nodes[-1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-15)) utt.assert_allclose(f(val_no0), val_no0 ** (-15))
f = function([v], v ** (16), mode=mode) f = function([v], v ** (16), mode=mode)
...@@ -2425,7 +2424,7 @@ def test_local_pow_specialize_device_more_aggressive_on_cpu(): ...@@ -2425,7 +2424,7 @@ def test_local_pow_specialize_device_more_aggressive_on_cpu():
assert len(nodes) == 2 assert len(nodes) == 2
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 4 assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 4
assert isinstance(nodes[0].scalar_op, aes.Composite) assert isinstance(nodes[0].scalar_op, aes.Composite)
assert isinstance(nodes[-1].scalar_op, aes.basic.Inv) assert isinstance(nodes[-1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-16)) utt.assert_allclose(f(val_no0), val_no0 ** (-16))
...@@ -2475,7 +2474,7 @@ class TestFuncInverse: ...@@ -2475,7 +2474,7 @@ class TestFuncInverse:
self.assert_func_pair_optimized(cosh, arccosh, dx) self.assert_func_pair_optimized(cosh, arccosh, dx)
self.assert_func_pair_optimized(arcsinh, sinh, dx) self.assert_func_pair_optimized(arcsinh, sinh, dx)
self.assert_func_pair_optimized(arctanh, tanh, dx) self.assert_func_pair_optimized(arctanh, tanh, dx)
self.assert_func_pair_optimized(inv, inv, dx) self.assert_func_pair_optimized(reciprocal, reciprocal, dx)
self.assert_func_pair_optimized(neg, neg, dx) self.assert_func_pair_optimized(neg, neg, dx)
cx = dx + complex(0, 1) * (dx + 0.01) cx = dx + complex(0, 1) * (dx + 0.01)
self.assert_func_pair_optimized(conj, conj, cx, is_complex=True) self.assert_func_pair_optimized(conj, conj, cx, is_complex=True)
...@@ -2826,12 +2825,13 @@ class TestLocalErfc: ...@@ -2826,12 +2825,13 @@ class TestLocalErfc:
def test_local_log_erfc(self): def test_local_log_erfc(self):
val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30] val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30]
if config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]: if config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]:
# python mode don't like the inv(0) # python mode doesn't like the reciprocal(0)
val.remove(0) val.remove(0)
val = np.asarray(val, dtype=config.floatX) val = np.asarray(val, dtype=config.floatX)
x = vector("x") x = vector("x")
# their is some nan that will happear in the graph for the log of the negatives values # there are some `nan`s that will appear in the graph due to the logs
# of negative values
mode = copy.copy(self.mode) mode = copy.copy(self.mode)
mode.check_isfinite = False mode.check_isfinite = False
mode_fusion = copy.copy(self.mode_fusion) mode_fusion = copy.copy(self.mode_fusion)
...@@ -3761,7 +3761,8 @@ def test_local_add_specialize(): ...@@ -3761,7 +3761,8 @@ def test_local_add_specialize():
assert transformed[0].type == s.type assert transformed[0].type == s.type
def test_local_div_to_inv(): def test_local_div_to_reciprocal():
# XXX TODO: This does *not* test `local_div_to_reciprocal`!
num_len_s = lscalar("num_len") num_len_s = lscalar("num_len")
denom_s = scalar("denom") denom_s = scalar("denom")
......
...@@ -840,7 +840,7 @@ _grad_broadcast_binary_normal = dict( ...@@ -840,7 +840,7 @@ _grad_broadcast_binary_normal = dict(
# complex3=(rand(2,3),randcomplex(2,3)), # complex3=(rand(2,3),randcomplex(2,3)),
) )
_good_inv = dict( _good_reciprocal = dict(
normal=[5 * rand_nonzero((2, 3))], normal=[5 * rand_nonzero((2, 3))],
integers=[randint_nonzero(2, 3)], integers=[randint_nonzero(2, 3)],
int8=[np.array(list(range(-127, 0)) + list(range(1, 127)), dtype="int8")], int8=[np.array(list(range(-127, 0)) + list(range(1, 127)), dtype="int8")],
...@@ -850,14 +850,15 @@ _good_inv = dict( ...@@ -850,14 +850,15 @@ _good_inv = dict(
empty=[np.asarray([], dtype=config.floatX)], empty=[np.asarray([], dtype=config.floatX)],
) )
_good_inv_inplace = copymod( _good_reciprocal_inplace = copymod(
_good_inv, without=["integers", "int8", "uint8", "uint16", "complex"] _good_reciprocal, without=["integers", "int8", "uint8", "uint16", "complex"]
) )
_grad_inv = copymod( _grad_reciprocal = copymod(
_good_inv, without=["integers", "int8", "uint8", "uint16", "complex", "empty"] _good_reciprocal,
without=["integers", "int8", "uint8", "uint16", "complex", "empty"],
) )
_bad_runtime_inv = dict( _bad_runtime_reciprocal = dict(
float=[np.zeros((2, 3))], float=[np.zeros((2, 3))],
integers=[np.zeros((2, 3), dtype="int64")], integers=[np.zeros((2, 3), dtype="int64")],
int8=[np.zeros((2, 3), dtype="int8")], int8=[np.zeros((2, 3), dtype="int8")],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论