提交 deb437d8 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Rename scalar Minimum and Maximum Ops to ScalarMinimum and ScalarMaximum

Closes #55
上级 91ffe60b
......@@ -458,7 +458,7 @@ class TestGpuCAReduceCuda(TestGpuCAReduceCPY):
# ((5,4,3,10,11),[1,2]),
]
op = GpuCAReduceCuda
reds = [scalar.add, scalar.mul, scalar.maximum, scalar.minimum]
reds = [scalar.add, scalar.mul, scalar.scalar_maximum, scalar.scalar_minimum]
pre_scalar_op = None
def test_perform_noopt(self):
......
......@@ -1068,7 +1068,9 @@ def test_argmax_pushdown():
assert isinstance(fgraph.toposort()[0].op, tt.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tt.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert isinstance(
fgraph.toposort()[2].op.scalar_op, theano.scalar.ScalarMaximum
)
def test_argmax_pushdown_bias():
......@@ -1098,7 +1100,7 @@ def test_argmax_pushdown_bias():
assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
assert isinstance(fgraph.toposort()[1].op, tt.CAReduce)
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum)
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.ScalarMaximum)
assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, tt.CAReduce))
......
......@@ -458,13 +458,13 @@ class TestCAReduce(unittest_tools.InferShapeTester):
elif scalar_op == scalar.mul:
for axis in reversed(sorted(tosum)):
zv = np.multiply.reduce(zv, axis)
elif scalar_op == scalar.maximum:
elif scalar_op == scalar.scalar_maximum:
try:
for axis in reversed(sorted(tosum)):
zv = np.maximum.reduce(zv, axis)
except ValueError:
numpy_raised = True
elif scalar_op == scalar.minimum:
elif scalar_op == scalar.scalar_minimum:
try:
for axis in reversed(sorted(tosum)):
zv = np.minimum.reduce(zv, axis)
......@@ -487,7 +487,10 @@ class TestCAReduce(unittest_tools.InferShapeTester):
raise Exception(
f"Test for CAReduce with scalar_op {scalar_op} not implemented"
)
if scalar_op in [scalar.maximum, scalar.minimum] and numpy_raised:
if (
scalar_op in [scalar.scalar_maximum, scalar.scalar_minimum]
and numpy_raised
):
with pytest.raises(ValueError):
f(xv)
else:
......@@ -515,7 +518,7 @@ class TestCAReduce(unittest_tools.InferShapeTester):
tosum = list(range(len(xsh)))
f = theano.function([x], e.shape, mode=mode)
if not (
scalar_op in [scalar.maximum, scalar.minimum]
scalar_op in [scalar.scalar_maximum, scalar.scalar_minimum]
and (xsh == () or np.prod(xsh) == 0)
):
try:
......@@ -531,8 +534,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_mode(Mode(linker="py"), scalar.add, dtype=dtype)
self.with_mode(Mode(linker="py"), scalar.mul, dtype=dtype)
self.with_mode(Mode(linker="py"), scalar.maximum, dtype=dtype)
self.with_mode(Mode(linker="py"), scalar.minimum, dtype=dtype)
self.with_mode(Mode(linker="py"), scalar.scalar_maximum, dtype=dtype)
self.with_mode(Mode(linker="py"), scalar.scalar_minimum, dtype=dtype)
self.with_mode(
Mode(linker="py"), scalar.and_, dtype=dtype, tensor_op=tt.all
)
......@@ -547,10 +550,10 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="py"), scalar.add, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="py"), scalar.mul, dtype=dtype, test_nan=True)
self.with_mode(
Mode(linker="py"), scalar.maximum, dtype=dtype, test_nan=True
Mode(linker="py"), scalar.scalar_maximum, dtype=dtype, test_nan=True
)
self.with_mode(
Mode(linker="py"), scalar.minimum, dtype=dtype, test_nan=True
Mode(linker="py"), scalar.scalar_minimum, dtype=dtype, test_nan=True
)
self.with_mode(
Mode(linker="py"),
......@@ -584,8 +587,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="c"), scalar.add, dtype=dtype)
self.with_mode(Mode(linker="c"), scalar.mul, dtype=dtype)
for dtype in ["bool", "floatX", "int8", "uint8"]:
self.with_mode(Mode(linker="c"), scalar.minimum, dtype=dtype)
self.with_mode(Mode(linker="c"), scalar.maximum, dtype=dtype)
self.with_mode(Mode(linker="c"), scalar.scalar_minimum, dtype=dtype)
self.with_mode(Mode(linker="c"), scalar.scalar_maximum, dtype=dtype)
self.with_mode(Mode(linker="c"), scalar.and_, dtype=dtype, tensor_op=tt.all)
self.with_mode(Mode(linker="c"), scalar.or_, dtype=dtype, tensor_op=tt.any)
for dtype in ["bool", "int8", "uint8"]:
......@@ -602,8 +605,12 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="c"), scalar.add, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="c"), scalar.mul, dtype=dtype, test_nan=True)
for dtype in ["floatX"]:
self.with_mode(Mode(linker="c"), scalar.minimum, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="c"), scalar.maximum, dtype=dtype, test_nan=True)
self.with_mode(
Mode(linker="c"), scalar.scalar_minimum, dtype=dtype, test_nan=True
)
self.with_mode(
Mode(linker="c"), scalar.scalar_maximum, dtype=dtype, test_nan=True
)
def test_infer_shape(self, dtype=None, pre_scalar_op=None):
if dtype is None:
......
......@@ -8063,7 +8063,10 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
fgraph = f.maker.fgraph.toposort()
for node in fgraph:
if hasattr(node.op, "scalar_op") and node.op.scalar_op == scal.basic.maximum:
if (
hasattr(node.op, "scalar_op")
and node.op.scalar_op == scal.basic.scalar_maximum
):
return
# in mode FAST_COMPILE, the optimisations don't replace the
......
......@@ -325,7 +325,7 @@ def test_scan_debugprint2():
expected_output = """Sum{acc_dtype=float64} [id A] ''
|for{cpu,scan_fn} [id B] ''
|Elemwise{minimum,no_inplace} [id C] ''
|Elemwise{scalar_minimum,no_inplace} [id C] ''
| |Subtensor{int64} [id D] ''
| | |Shape [id E] ''
| | | |Subtensor{int64::} [id F] 'coefficients[0:]'
......@@ -344,12 +344,12 @@ def test_scan_debugprint2():
|Subtensor{:int64:} [id S] ''
| |Subtensor{int64::} [id F] 'coefficients[0:]'
| |ScalarFromTensor [id T] ''
| |Elemwise{minimum,no_inplace} [id C] ''
| |Elemwise{scalar_minimum,no_inplace} [id C] ''
|Subtensor{:int64:} [id U] ''
| |Subtensor{int64::} [id L] ''
| |ScalarFromTensor [id V] ''
| |Elemwise{minimum,no_inplace} [id C] ''
|Elemwise{minimum,no_inplace} [id C] ''
| |Elemwise{scalar_minimum,no_inplace} [id C] ''
|Elemwise{scalar_minimum,no_inplace} [id C] ''
|x [id W]
Inner graphs of the scan ops:
......@@ -404,7 +404,7 @@ def test_scan_debugprint3():
expected_output = """Sum{acc_dtype=float64} [id A] ''
|for{cpu,scan_fn} [id B] ''
|Elemwise{minimum,no_inplace} [id C] ''
|Elemwise{scalar_minimum,no_inplace} [id C] ''
| |Subtensor{int64} [id D] ''
| | |Shape [id E] ''
| | | |Subtensor{int64::} [id F] 'coefficients[0:]'
......@@ -423,12 +423,12 @@ def test_scan_debugprint3():
|Subtensor{:int64:} [id S] ''
| |Subtensor{int64::} [id F] 'coefficients[0:]'
| |ScalarFromTensor [id T] ''
| |Elemwise{minimum,no_inplace} [id C] ''
| |Elemwise{scalar_minimum,no_inplace} [id C] ''
|Subtensor{:int64:} [id U] ''
| |Subtensor{int64::} [id L] ''
| |ScalarFromTensor [id V] ''
| |Elemwise{minimum,no_inplace} [id C] ''
|Elemwise{minimum,no_inplace} [id C] ''
| |Elemwise{scalar_minimum,no_inplace} [id C] ''
|Elemwise{scalar_minimum,no_inplace} [id C] ''
|A [id W]
|k [id X]
......
......@@ -1532,8 +1532,8 @@ class ProfileStats:
scal.XOR,
scal.AND,
scal.Invert,
scal.Maximum,
scal.Minimum,
scal.ScalarMaximum,
scal.ScalarMinimum,
scal.Add,
scal.Mul,
scal.Sub,
......
......@@ -738,9 +738,9 @@ def local_dnn_reduction(fgraph, node):
scal = "norm1"
else:
return
elif isinstance(node.op.scalar_op, theano.scalar.basic.Maximum) and isinstance(
node.op.pre_scalar_op, theano.scalar.basic.Abs
):
elif isinstance(
node.op.scalar_op, theano.scalar.basic.ScalarMaximum
) and isinstance(node.op.pre_scalar_op, theano.scalar.basic.Abs):
scal = "absmax"
else:
return
......
......@@ -703,7 +703,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
# It might be nice to use a property of the op class to do this,
# but tensor.elemwise.CAReduce has this exact same check so I guess
# this is OK to do
if self.scalar_op in [scalar.minimum, scalar.maximum]:
if self.scalar_op in [scalar.scalar_minimum, scalar.scalar_maximum]:
conds = [
f"(PyGpuArray_DIMS({x})[{i}] == 0)"
for i in range(nd_in)
......@@ -1083,7 +1083,9 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
if hasattr(self.scalar_op, "identity"):
return str(self.scalar_op.identity)
else:
assert isinstance(self.scalar_op, (scalar.Maximum, scalar.Minimum))
assert isinstance(
self.scalar_op, (scalar.ScalarMaximum, scalar.ScalarMinimum)
)
if self.pre_scalar_op: # TODO: multiple dtypes
# dtype = node.inputs[0].dtype
......
......@@ -1205,7 +1205,8 @@ def local_gpu_extract_diag(fgraph, op, context_name, inputs, outputs):
@register_opt2([tensor.CAReduce, tensor.Sum, tensor.elemwise.Prod], "fast_compile")
def local_gpua_careduce(fgraph, op, context_name, inputs, outputs):
if isinstance(
op.scalar_op, (scalar.Add, scalar.Mul, scalar.Maximum, scalar.Minimum)
op.scalar_op,
(scalar.Add, scalar.Mul, scalar.ScalarMaximum, scalar.ScalarMinimum),
):
ctx = get_context(context_name)
......
......@@ -1729,7 +1729,7 @@ invert = Invert()
##############
# Arithmetic
##############
class Maximum(BinaryScalarOp):
class ScalarMaximum(BinaryScalarOp):
commutative = True
associative = True
nfunc_spec = ("maximum", 2, 1)
......@@ -1768,10 +1768,10 @@ class Maximum(BinaryScalarOp):
return (gx, gy)
maximum = Maximum(upcast_out, name="maximum")
scalar_maximum = ScalarMaximum(upcast_out, name="maximum")
class Minimum(BinaryScalarOp):
class ScalarMinimum(BinaryScalarOp):
commutative = True
associative = True
nfunc_spec = ("minimum", 2, 1)
......@@ -1809,7 +1809,7 @@ class Minimum(BinaryScalarOp):
return (gx, gy)
minimum = Minimum(upcast_out, name="minimum")
scalar_minimum = ScalarMinimum(upcast_out, name="minimum")
class Add(ScalarOp):
......
......@@ -380,7 +380,7 @@ def numpy_scalar(data):
)
get_scalar_constant_value_elemwises = (
_scalar_constant_value_elemwise_ops = (
scal.Cast,
scal.Switch,
scal.NEQ,
......@@ -395,8 +395,8 @@ get_scalar_constant_value_elemwises = (
scal.Mul,
scal.IntDiv,
scal.TrueDiv,
scal.Minimum,
scal.Maximum,
scal.ScalarMinimum,
scal.ScalarMaximum,
)
......@@ -502,7 +502,7 @@ def get_scalar_constant_value(
shp, val = v.owner.inputs
v = val
continue
if isinstance(v.owner.op, get_scalar_constant_value_elemwises):
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
const = [
get_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs
......@@ -520,7 +520,7 @@ def get_scalar_constant_value(
v = val
continue
elif elemwise and isinstance(
v.owner.op.scalar_op, get_scalar_constant_value_elemwises
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
):
const = [
get_scalar_constant_value(i, max_recur=max_recur)
......@@ -1042,7 +1042,7 @@ elemwise.TensorConstant = TensorConstant
#########################
def _scal_elemwise_with_nfunc(nfunc, nin, nout):
def _scal_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
"""
Replace a symbol definition with an elementwise version of the
corresponding scalar Op. If it is not None, the nfunc argument
......@@ -1056,48 +1056,47 @@ def _scal_elemwise_with_nfunc(nfunc, nin, nout):
"""
def construct(symbol):
symbolname = symbol.__name__
inplace = symbolname.endswith("_inplace")
if inplace:
msg = "inplace"
else:
msg = "no_inplace"
nonlocal symbolname
n = f"Elemwise{{{symbolname},{msg}}}"
symbolname = symbolname or symbol.__name__
if inplace:
if symbolname.endswith("_inplace"):
elemwise_name = f"Elemwise{{{symbolname},inplace}}"
scalar_op = getattr(scal, symbolname[: -len("_inplace")])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(
inplace_scalar_op,
{0: 0},
name=n,
name=elemwise_name,
nfunc_spec=(nfunc and (nfunc, nin, nout)),
)
else:
elemwise_name = f"Elemwise{{{symbolname},no_inplace}}"
scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise(
scalar_op, name=n, nfunc_spec=(nfunc and (nfunc, nin, nout))
scalar_op, name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout))
)
if getattr(symbol, "__doc__", False):
if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__ + "\n" + rval.__doc__
# for the meaning of this see the ./epydoc script
# it makes epydoc display rval as if it were a function, not an object
rval.__epydoc_asRoutine = symbol
rval.__module__ = "tensor"
rval.__module__ = symbol.__module__
pprint.assign(rval, printing.FunctionPrinter(symbolname))
pprint.assign(
rval, printing.FunctionPrinter(symbolname.replace("_inplace", "="))
)
return rval
if symbol:
return construct(symbol[0])
else:
return construct
_scal_elemwise = _scal_elemwise_with_nfunc(None, None, None)
def _pack(x):
"""
Convert x to a list if it is an iterable, otherwise wrap it in a list.
......@@ -1772,14 +1771,14 @@ class Max(CAReduce):
nfunc_spec = ("max", 1, 1)
def __init__(self, axis):
super().__init__(scal.maximum, axis)
super().__init__(scal.scalar_maximum, axis)
class Min(CAReduce):
nfunc_spec = ("min", 1, 1)
def __init__(self, axis):
super().__init__(scal.minimum, axis)
super().__init__(scal.scalar_minimum, axis)
@constructor
......@@ -3661,13 +3660,13 @@ setdefault = default # legacy
##########################
# Arithmetics
##########################
@_scal_elemwise
@_scal_elemwise(symbolname="scalar_maximum")
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise(symbolname="scalar_minimum")
def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor"""
# see decorator for function body
......
......@@ -1348,9 +1348,9 @@ class CAReduce(COp):
self.ufunc = np.add
elif isinstance(scalar_op, theano.scalar.basic.Mul):
self.ufunc = np.multiply
elif isinstance(scalar_op, theano.scalar.basic.Maximum):
elif isinstance(scalar_op, theano.scalar.basic.ScalarMaximum):
self.ufunc = np.maximum
elif isinstance(scalar_op, theano.scalar.basic.Minimum):
elif isinstance(scalar_op, theano.scalar.basic.ScalarMinimum):
self.ufunc = np.minimum
elif isinstance(scalar_op, theano.scalar.basic.AND) and _numpy_ver >= [1, 12]:
# numpy.bitwise_and.identity was incorrect for versions before
......@@ -1570,8 +1570,8 @@ class CAReduce(COp):
if hasattr(self.scalar_op, "identity"):
identity = self.scalar_op.identity
elif self.scalar_op in [scalar.maximum, scalar.minimum]:
if self.scalar_op == scalar.maximum:
elif self.scalar_op in [scalar.scalar_maximum, scalar.scalar_minimum]:
if self.scalar_op == scalar.scalar_maximum:
scal_name = "maximum"
if input.type.dtype in ["float32", "float64"]:
identity = "-__builtin_inf()"
......@@ -1580,7 +1580,7 @@ class CAReduce(COp):
identity = "0"
else:
identity = "NPY_MIN_" + str(input.type.dtype).upper()
if self.scalar_op == scalar.minimum:
if self.scalar_op == scalar.scalar_minimum:
scal_name = "minimum"
if input.type.dtype in ["float32", "float64"]:
identity = "__builtin_inf()"
......
from theano import printing
from theano import scalar as scal
from theano.printing import pprint
from theano.tensor import elemwise
from theano.tensor.basic import _scal_elemwise
from . import elemwise
def _scal_inplace(symbol):
    """Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
    # The decorated function's name selects the scalar Op to wrap; an
    # "_inplace" suffix requests the destructive variant whose output
    # aliases input 0.
    name = symbol.__name__
    if name.endswith("_inplace"):
        base_op = getattr(scal, name[: -len("_inplace")])
        # Rebuild the scalar op with transfer_type(0) so the output reuses
        # input 0's storage, and declare the 0->0 aliasing to Elemwise.
        destructive_op = base_op.__class__(scal.transfer_type(0))
        elemwise_op = elemwise.Elemwise(destructive_op, {0: 0}, name=name)
    else:
        elemwise_op = elemwise.Elemwise(getattr(scal, name), name=name)
    doc = getattr(symbol, "__doc__", False)
    if doc:
        elemwise_op.__doc__ = doc + "\n" + elemwise_op.__doc__
    # for the meaning of this see the ./epydoc script
    # it makes epydoc display rval as if it were a function, not an object
    elemwise_op.__epydoc_asRoutine = symbol
    elemwise_op.__module__ = "theano.tensor.inplace"
    pprint.assign(
        elemwise_op, printing.FunctionPrinter(name.replace("_inplace", "="))
    )
    return elemwise_op
@_scal_inplace
@_scal_elemwise
def lt_inplace(a, b):
"""a < b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def gt_inplace(a, b):
"""a > b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def le_inplace(a, b):
"""a <= b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def ge_inplace(a, b):
"""a >= b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def eq_inplace(a, b):
"""a == b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def neq_inplace(a, b):
"""a != b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def and__inplace(a, b):
"""bitwise a & b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def or__inplace(a, b):
"""bitwise a | b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def xor_inplace(a, b):
"""bitwise a ^ b (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def invert_inplace(a):
"""bitwise ~a (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def abs__inplace(a):
"""|`a`| (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def exp2_inplace(a):
"""2^`a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def expm1_inplace(a):
"""e^`a` - 1 (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def neg_inplace(a):
"""-a (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def inv_inplace(a):
"""1.0/a (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def log1p_inplace(a):
"""log(1+a)"""
@_scal_inplace
@_scal_elemwise
def log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def log10_inplace(a):
"""base 10 logarithm of a (inplace on a)"""
@_scal_inplace
@_scal_elemwise
def sgn_inplace(a):
"""sign of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def ceil_inplace(a):
"""ceil of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def floor_inplace(a):
"""floor of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def trunc_inplace(a):
"""trunc of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def round_half_to_even_inplace(a):
"""round_half_to_even_inplace(a) (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def round_half_away_from_zero_inplace(a):
"""round_half_away_from_zero_inplace(a) (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def deg2rad_inplace(a):
"""convert degree `a` to radian(inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def rad2deg_inplace(a):
"""convert radian `a` to degree(inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def arccos_inplace(a):
"""arccosine of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def arcsin_inplace(a):
"""arcsine of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def arctan_inplace(a):
"""arctangent of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def arctan2_inplace(a, b):
"""arctangent of `a` / `b` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def arccosh_inplace(a):
"""hyperbolic arc cosine of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def arcsinh_inplace(a):
"""hyperbolic arc sine of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def arctanh_inplace(a):
"""hyperbolic arc tangent of `a` (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def erf_inplace(a):
"""error function"""
@_scal_inplace
@_scal_elemwise
def erfc_inplace(a):
"""complementary error function"""
@_scal_inplace
@_scal_elemwise
def erfcx_inplace(a):
"""scaled complementary error function"""
@_scal_inplace
@_scal_elemwise
def gamma_inplace(a):
"""gamma function"""
@_scal_inplace
@_scal_elemwise
def gammaln_inplace(a):
"""log gamma function"""
@_scal_inplace
@_scal_elemwise
def psi_inplace(a):
"""derivative of log gamma function"""
@_scal_inplace
@_scal_elemwise
def tri_gamma_inplace(a):
"""second derivative of the log gamma function"""
@_scal_inplace
@_scal_elemwise
def chi2sf_inplace(x, k):
"""chi squared survival function"""
@_scal_inplace
@_scal_elemwise
def j0_inplace(x):
"""Bessel function of the first kind of order 0."""
@_scal_inplace
@_scal_elemwise
def j1_inplace(x):
"""Bessel function of the first kind of order 1."""
@_scal_inplace
@_scal_elemwise
def jv_inplace(v, x):
"""Bessel function of the first kind of order v (real)."""
@_scal_inplace
@_scal_elemwise
def i0_inplace(x):
"""Modified Bessel function of the first kind of order 0."""
@_scal_inplace
@_scal_elemwise
def i1_inplace(x):
"""Modified Bessel function of the first kind of order 1."""
@_scal_inplace
@_scal_elemwise
def iv_inplace(v, x):
"""Modified Bessel function of the first kind of order v (real)."""
@_scal_inplace
@_scal_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
......@@ -324,52 +298,52 @@ fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter("fill="))
@_scal_inplace
@_scal_elemwise(symbolname="scalar_maximum_inplace")
def maximum_inplace(a, b):
    """elementwise maximum (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise(symbolname="scalar_minimum_inplace")
def minimum_inplace(a, b):
    """elementwise minimum (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def true_div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def int_div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
@_scal_inplace
@_scal_elemwise
def conj_inplace(a):
"""elementwise conjugate (inplace on `a`)"""
......
......@@ -5624,7 +5624,7 @@ def local_useless_elemwise_comparison(fgraph, node):
return [res]
# Elemwise[{minimum,maximum}](X, X) -> X
if (
isinstance(node.op.scalar_op, (ts.Minimum, ts.Maximum))
isinstance(node.op.scalar_op, (ts.ScalarMinimum, ts.ScalarMaximum))
and node.inputs[0] is node.inputs[1]
):
res = node.inputs[0]
......@@ -5656,7 +5656,7 @@ def local_useless_elemwise_comparison(fgraph, node):
return [res]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if (
isinstance(node.op.scalar_op, ts.Maximum)
isinstance(node.op.scalar_op, ts.ScalarMaximum)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i)
and tt.extract_constant(node.inputs[1], only_process_constants=True) == 0
......@@ -5665,7 +5665,7 @@ def local_useless_elemwise_comparison(fgraph, node):
return [node.inputs[0]]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
if (
isinstance(node.op.scalar_op, ts.Maximum)
isinstance(node.op.scalar_op, ts.ScalarMaximum)
and tt.extract_constant(node.inputs[0], only_process_constants=True) == 0
and node.inputs[1].owner
and isinstance(node.inputs[1].owner.op, Shape_i)
......@@ -5674,7 +5674,7 @@ def local_useless_elemwise_comparison(fgraph, node):
return [node.inputs[1]]
# Elemwise[minimum](X.shape[i], 0) -> 0
if (
isinstance(node.op.scalar_op, ts.Minimum)
isinstance(node.op.scalar_op, ts.ScalarMinimum)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i)
and tt.extract_constant(node.inputs[1], only_process_constants=True) == 0
......@@ -5686,7 +5686,7 @@ def local_useless_elemwise_comparison(fgraph, node):
# Elemwise[minimum](0, X.shape[i]) -> 0
if (
isinstance(node.op.scalar_op, ts.Minimum)
isinstance(node.op.scalar_op, ts.ScalarMinimum)
and tt.extract_constant(node.inputs[0], only_process_constants=True) == 0
and node.inputs[1].owner
and isinstance(node.inputs[1].owner.op, Shape_i)
......@@ -6039,7 +6039,7 @@ def local_reduce_join(fgraph, node):
if tt.extract_constant(join.inputs[0], only_process_constants=True) != 0:
return
if isinstance(node.op.scalar_op, (ts.Maximum, ts.Minimum)):
if isinstance(node.op.scalar_op, (ts.ScalarMaximum, ts.ScalarMinimum)):
# Support only 2 inputs for now
if len(join.inputs) != 3:
return
......
......@@ -82,7 +82,7 @@ def local_max_to_min(fgraph, node):
if (
max.owner
and isinstance(max.owner.op, CAReduce)
and max.owner.op.scalar_op == scal.maximum
and max.owner.op.scalar_op == scal.scalar_maximum
):
neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == tt.neg:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论