提交 deb437d8 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Rename scalar Minimum and Maximum Ops to ScalarMinimum and ScalarMaximum

Closes #55
上级 91ffe60b
...@@ -458,7 +458,7 @@ class TestGpuCAReduceCuda(TestGpuCAReduceCPY): ...@@ -458,7 +458,7 @@ class TestGpuCAReduceCuda(TestGpuCAReduceCPY):
# ((5,4,3,10,11),[1,2]), # ((5,4,3,10,11),[1,2]),
] ]
op = GpuCAReduceCuda op = GpuCAReduceCuda
reds = [scalar.add, scalar.mul, scalar.maximum, scalar.minimum] reds = [scalar.add, scalar.mul, scalar.scalar_maximum, scalar.scalar_minimum]
pre_scalar_op = None pre_scalar_op = None
def test_perform_noopt(self): def test_perform_noopt(self):
......
...@@ -1068,7 +1068,9 @@ def test_argmax_pushdown(): ...@@ -1068,7 +1068,9 @@ def test_argmax_pushdown():
assert isinstance(fgraph.toposort()[0].op, tt.Elemwise) assert isinstance(fgraph.toposort()[0].op, tt.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax) assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tt.CAReduce) assert isinstance(fgraph.toposort()[2].op, tt.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum) assert isinstance(
fgraph.toposort()[2].op.scalar_op, theano.scalar.ScalarMaximum
)
def test_argmax_pushdown_bias(): def test_argmax_pushdown_bias():
...@@ -1098,7 +1100,7 @@ def test_argmax_pushdown_bias(): ...@@ -1098,7 +1100,7 @@ def test_argmax_pushdown_bias():
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias) assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
assert isinstance(fgraph.toposort()[1].op, tt.CAReduce) assert isinstance(fgraph.toposort()[1].op, tt.CAReduce)
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.ScalarMaximum)
assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, tt.CAReduce)) assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, tt.CAReduce))
......
...@@ -458,13 +458,13 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -458,13 +458,13 @@ class TestCAReduce(unittest_tools.InferShapeTester):
elif scalar_op == scalar.mul: elif scalar_op == scalar.mul:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = np.multiply.reduce(zv, axis) zv = np.multiply.reduce(zv, axis)
elif scalar_op == scalar.maximum: elif scalar_op == scalar.scalar_maximum:
try: try:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = np.maximum.reduce(zv, axis) zv = np.maximum.reduce(zv, axis)
except ValueError: except ValueError:
numpy_raised = True numpy_raised = True
elif scalar_op == scalar.minimum: elif scalar_op == scalar.scalar_minimum:
try: try:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = np.minimum.reduce(zv, axis) zv = np.minimum.reduce(zv, axis)
...@@ -487,7 +487,10 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -487,7 +487,10 @@ class TestCAReduce(unittest_tools.InferShapeTester):
raise Exception( raise Exception(
f"Test for CAReduce with scalar_op {scalar_op} not implemented" f"Test for CAReduce with scalar_op {scalar_op} not implemented"
) )
if scalar_op in [scalar.maximum, scalar.minimum] and numpy_raised: if (
scalar_op in [scalar.scalar_maximum, scalar.scalar_minimum]
and numpy_raised
):
with pytest.raises(ValueError): with pytest.raises(ValueError):
f(xv) f(xv)
else: else:
...@@ -515,7 +518,7 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -515,7 +518,7 @@ class TestCAReduce(unittest_tools.InferShapeTester):
tosum = list(range(len(xsh))) tosum = list(range(len(xsh)))
f = theano.function([x], e.shape, mode=mode) f = theano.function([x], e.shape, mode=mode)
if not ( if not (
scalar_op in [scalar.maximum, scalar.minimum] scalar_op in [scalar.scalar_maximum, scalar.scalar_minimum]
and (xsh == () or np.prod(xsh) == 0) and (xsh == () or np.prod(xsh) == 0)
): ):
try: try:
...@@ -531,8 +534,8 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -531,8 +534,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]: for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_mode(Mode(linker="py"), scalar.add, dtype=dtype) self.with_mode(Mode(linker="py"), scalar.add, dtype=dtype)
self.with_mode(Mode(linker="py"), scalar.mul, dtype=dtype) self.with_mode(Mode(linker="py"), scalar.mul, dtype=dtype)
self.with_mode(Mode(linker="py"), scalar.maximum, dtype=dtype) self.with_mode(Mode(linker="py"), scalar.scalar_maximum, dtype=dtype)
self.with_mode(Mode(linker="py"), scalar.minimum, dtype=dtype) self.with_mode(Mode(linker="py"), scalar.scalar_minimum, dtype=dtype)
self.with_mode( self.with_mode(
Mode(linker="py"), scalar.and_, dtype=dtype, tensor_op=tt.all Mode(linker="py"), scalar.and_, dtype=dtype, tensor_op=tt.all
) )
...@@ -547,10 +550,10 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -547,10 +550,10 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="py"), scalar.add, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="py"), scalar.add, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="py"), scalar.mul, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="py"), scalar.mul, dtype=dtype, test_nan=True)
self.with_mode( self.with_mode(
Mode(linker="py"), scalar.maximum, dtype=dtype, test_nan=True Mode(linker="py"), scalar.scalar_maximum, dtype=dtype, test_nan=True
) )
self.with_mode( self.with_mode(
Mode(linker="py"), scalar.minimum, dtype=dtype, test_nan=True Mode(linker="py"), scalar.scalar_minimum, dtype=dtype, test_nan=True
) )
self.with_mode( self.with_mode(
Mode(linker="py"), Mode(linker="py"),
...@@ -584,8 +587,8 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -584,8 +587,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="c"), scalar.add, dtype=dtype) self.with_mode(Mode(linker="c"), scalar.add, dtype=dtype)
self.with_mode(Mode(linker="c"), scalar.mul, dtype=dtype) self.with_mode(Mode(linker="c"), scalar.mul, dtype=dtype)
for dtype in ["bool", "floatX", "int8", "uint8"]: for dtype in ["bool", "floatX", "int8", "uint8"]:
self.with_mode(Mode(linker="c"), scalar.minimum, dtype=dtype) self.with_mode(Mode(linker="c"), scalar.scalar_minimum, dtype=dtype)
self.with_mode(Mode(linker="c"), scalar.maximum, dtype=dtype) self.with_mode(Mode(linker="c"), scalar.scalar_maximum, dtype=dtype)
self.with_mode(Mode(linker="c"), scalar.and_, dtype=dtype, tensor_op=tt.all) self.with_mode(Mode(linker="c"), scalar.and_, dtype=dtype, tensor_op=tt.all)
self.with_mode(Mode(linker="c"), scalar.or_, dtype=dtype, tensor_op=tt.any) self.with_mode(Mode(linker="c"), scalar.or_, dtype=dtype, tensor_op=tt.any)
for dtype in ["bool", "int8", "uint8"]: for dtype in ["bool", "int8", "uint8"]:
...@@ -602,8 +605,12 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -602,8 +605,12 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="c"), scalar.add, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="c"), scalar.add, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="c"), scalar.mul, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="c"), scalar.mul, dtype=dtype, test_nan=True)
for dtype in ["floatX"]: for dtype in ["floatX"]:
self.with_mode(Mode(linker="c"), scalar.minimum, dtype=dtype, test_nan=True) self.with_mode(
self.with_mode(Mode(linker="c"), scalar.maximum, dtype=dtype, test_nan=True) Mode(linker="c"), scalar.scalar_minimum, dtype=dtype, test_nan=True
)
self.with_mode(
Mode(linker="c"), scalar.scalar_maximum, dtype=dtype, test_nan=True
)
def test_infer_shape(self, dtype=None, pre_scalar_op=None): def test_infer_shape(self, dtype=None, pre_scalar_op=None):
if dtype is None: if dtype is None:
......
...@@ -8063,7 +8063,10 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None): ...@@ -8063,7 +8063,10 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
fgraph = f.maker.fgraph.toposort() fgraph = f.maker.fgraph.toposort()
for node in fgraph: for node in fgraph:
if hasattr(node.op, "scalar_op") and node.op.scalar_op == scal.basic.maximum: if (
hasattr(node.op, "scalar_op")
and node.op.scalar_op == scal.basic.scalar_maximum
):
return return
# in mode FAST_COMPILE, the optimisations don't replace the # in mode FAST_COMPILE, the optimisations don't replace the
......
...@@ -325,7 +325,7 @@ def test_scan_debugprint2(): ...@@ -325,7 +325,7 @@ def test_scan_debugprint2():
expected_output = """Sum{acc_dtype=float64} [id A] '' expected_output = """Sum{acc_dtype=float64} [id A] ''
|for{cpu,scan_fn} [id B] '' |for{cpu,scan_fn} [id B] ''
|Elemwise{minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] ''
| |Subtensor{int64} [id D] '' | |Subtensor{int64} [id D] ''
| | |Shape [id E] '' | | |Shape [id E] ''
| | | |Subtensor{int64::} [id F] 'coefficients[0:]' | | | |Subtensor{int64::} [id F] 'coefficients[0:]'
...@@ -344,12 +344,12 @@ def test_scan_debugprint2(): ...@@ -344,12 +344,12 @@ def test_scan_debugprint2():
|Subtensor{:int64:} [id S] '' |Subtensor{:int64:} [id S] ''
| |Subtensor{int64::} [id F] 'coefficients[0:]' | |Subtensor{int64::} [id F] 'coefficients[0:]'
| |ScalarFromTensor [id T] '' | |ScalarFromTensor [id T] ''
| |Elemwise{minimum,no_inplace} [id C] '' | |Elemwise{scalar_minimum,no_inplace} [id C] ''
|Subtensor{:int64:} [id U] '' |Subtensor{:int64:} [id U] ''
| |Subtensor{int64::} [id L] '' | |Subtensor{int64::} [id L] ''
| |ScalarFromTensor [id V] '' | |ScalarFromTensor [id V] ''
| |Elemwise{minimum,no_inplace} [id C] '' | |Elemwise{scalar_minimum,no_inplace} [id C] ''
|Elemwise{minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] ''
|x [id W] |x [id W]
Inner graphs of the scan ops: Inner graphs of the scan ops:
...@@ -404,7 +404,7 @@ def test_scan_debugprint3(): ...@@ -404,7 +404,7 @@ def test_scan_debugprint3():
expected_output = """Sum{acc_dtype=float64} [id A] '' expected_output = """Sum{acc_dtype=float64} [id A] ''
|for{cpu,scan_fn} [id B] '' |for{cpu,scan_fn} [id B] ''
|Elemwise{minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] ''
| |Subtensor{int64} [id D] '' | |Subtensor{int64} [id D] ''
| | |Shape [id E] '' | | |Shape [id E] ''
| | | |Subtensor{int64::} [id F] 'coefficients[0:]' | | | |Subtensor{int64::} [id F] 'coefficients[0:]'
...@@ -423,12 +423,12 @@ def test_scan_debugprint3(): ...@@ -423,12 +423,12 @@ def test_scan_debugprint3():
|Subtensor{:int64:} [id S] '' |Subtensor{:int64:} [id S] ''
| |Subtensor{int64::} [id F] 'coefficients[0:]' | |Subtensor{int64::} [id F] 'coefficients[0:]'
| |ScalarFromTensor [id T] '' | |ScalarFromTensor [id T] ''
| |Elemwise{minimum,no_inplace} [id C] '' | |Elemwise{scalar_minimum,no_inplace} [id C] ''
|Subtensor{:int64:} [id U] '' |Subtensor{:int64:} [id U] ''
| |Subtensor{int64::} [id L] '' | |Subtensor{int64::} [id L] ''
| |ScalarFromTensor [id V] '' | |ScalarFromTensor [id V] ''
| |Elemwise{minimum,no_inplace} [id C] '' | |Elemwise{scalar_minimum,no_inplace} [id C] ''
|Elemwise{minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] ''
|A [id W] |A [id W]
|k [id X] |k [id X]
......
...@@ -1532,8 +1532,8 @@ class ProfileStats: ...@@ -1532,8 +1532,8 @@ class ProfileStats:
scal.XOR, scal.XOR,
scal.AND, scal.AND,
scal.Invert, scal.Invert,
scal.Maximum, scal.ScalarMaximum,
scal.Minimum, scal.ScalarMinimum,
scal.Add, scal.Add,
scal.Mul, scal.Mul,
scal.Sub, scal.Sub,
......
...@@ -738,9 +738,9 @@ def local_dnn_reduction(fgraph, node): ...@@ -738,9 +738,9 @@ def local_dnn_reduction(fgraph, node):
scal = "norm1" scal = "norm1"
else: else:
return return
elif isinstance(node.op.scalar_op, theano.scalar.basic.Maximum) and isinstance( elif isinstance(
node.op.pre_scalar_op, theano.scalar.basic.Abs node.op.scalar_op, theano.scalar.basic.ScalarMaximum
): ) and isinstance(node.op.pre_scalar_op, theano.scalar.basic.Abs):
scal = "absmax" scal = "absmax"
else: else:
return return
......
...@@ -703,7 +703,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype): ...@@ -703,7 +703,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
# It might be nice to use a property of the op class to do this, # It might be nice to use a property of the op class to do this,
# but tensor.elemwise.CAReduce has this exact same check so I guess # but tensor.elemwise.CAReduce has this exact same check so I guess
# this is OK to do # this is OK to do
if self.scalar_op in [scalar.minimum, scalar.maximum]: if self.scalar_op in [scalar.scalar_minimum, scalar.scalar_maximum]:
conds = [ conds = [
f"(PyGpuArray_DIMS({x})[{i}] == 0)" f"(PyGpuArray_DIMS({x})[{i}] == 0)"
for i in range(nd_in) for i in range(nd_in)
...@@ -1083,7 +1083,9 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype): ...@@ -1083,7 +1083,9 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
if hasattr(self.scalar_op, "identity"): if hasattr(self.scalar_op, "identity"):
return str(self.scalar_op.identity) return str(self.scalar_op.identity)
else: else:
assert isinstance(self.scalar_op, (scalar.Maximum, scalar.Minimum)) assert isinstance(
self.scalar_op, (scalar.ScalarMaximum, scalar.ScalarMinimum)
)
if self.pre_scalar_op: # TODO: multiple dtypes if self.pre_scalar_op: # TODO: multiple dtypes
# dtype = node.inputs[0].dtype # dtype = node.inputs[0].dtype
......
...@@ -1205,7 +1205,8 @@ def local_gpu_extract_diag(fgraph, op, context_name, inputs, outputs): ...@@ -1205,7 +1205,8 @@ def local_gpu_extract_diag(fgraph, op, context_name, inputs, outputs):
@register_opt2([tensor.CAReduce, tensor.Sum, tensor.elemwise.Prod], "fast_compile") @register_opt2([tensor.CAReduce, tensor.Sum, tensor.elemwise.Prod], "fast_compile")
def local_gpua_careduce(fgraph, op, context_name, inputs, outputs): def local_gpua_careduce(fgraph, op, context_name, inputs, outputs):
if isinstance( if isinstance(
op.scalar_op, (scalar.Add, scalar.Mul, scalar.Maximum, scalar.Minimum) op.scalar_op,
(scalar.Add, scalar.Mul, scalar.ScalarMaximum, scalar.ScalarMinimum),
): ):
ctx = get_context(context_name) ctx = get_context(context_name)
......
...@@ -1729,7 +1729,7 @@ invert = Invert() ...@@ -1729,7 +1729,7 @@ invert = Invert()
############## ##############
# Arithmetic # Arithmetic
############## ##############
class Maximum(BinaryScalarOp): class ScalarMaximum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("maximum", 2, 1) nfunc_spec = ("maximum", 2, 1)
...@@ -1768,10 +1768,10 @@ class Maximum(BinaryScalarOp): ...@@ -1768,10 +1768,10 @@ class Maximum(BinaryScalarOp):
return (gx, gy) return (gx, gy)
maximum = Maximum(upcast_out, name="maximum") scalar_maximum = ScalarMaximum(upcast_out, name="maximum")
class Minimum(BinaryScalarOp): class ScalarMinimum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("minimum", 2, 1) nfunc_spec = ("minimum", 2, 1)
...@@ -1809,7 +1809,7 @@ class Minimum(BinaryScalarOp): ...@@ -1809,7 +1809,7 @@ class Minimum(BinaryScalarOp):
return (gx, gy) return (gx, gy)
minimum = Minimum(upcast_out, name="minimum") scalar_minimum = ScalarMinimum(upcast_out, name="minimum")
class Add(ScalarOp): class Add(ScalarOp):
......
...@@ -380,7 +380,7 @@ def numpy_scalar(data): ...@@ -380,7 +380,7 @@ def numpy_scalar(data):
) )
get_scalar_constant_value_elemwises = ( _scalar_constant_value_elemwise_ops = (
scal.Cast, scal.Cast,
scal.Switch, scal.Switch,
scal.NEQ, scal.NEQ,
...@@ -395,8 +395,8 @@ get_scalar_constant_value_elemwises = ( ...@@ -395,8 +395,8 @@ get_scalar_constant_value_elemwises = (
scal.Mul, scal.Mul,
scal.IntDiv, scal.IntDiv,
scal.TrueDiv, scal.TrueDiv,
scal.Minimum, scal.ScalarMinimum,
scal.Maximum, scal.ScalarMaximum,
) )
...@@ -502,7 +502,7 @@ def get_scalar_constant_value( ...@@ -502,7 +502,7 @@ def get_scalar_constant_value(
shp, val = v.owner.inputs shp, val = v.owner.inputs
v = val v = val
continue continue
if isinstance(v.owner.op, get_scalar_constant_value_elemwises): if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
const = [ const = [
get_scalar_constant_value(i, max_recur=max_recur) get_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs for i in v.owner.inputs
...@@ -520,7 +520,7 @@ def get_scalar_constant_value( ...@@ -520,7 +520,7 @@ def get_scalar_constant_value(
v = val v = val
continue continue
elif elemwise and isinstance( elif elemwise and isinstance(
v.owner.op.scalar_op, get_scalar_constant_value_elemwises v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
): ):
const = [ const = [
get_scalar_constant_value(i, max_recur=max_recur) get_scalar_constant_value(i, max_recur=max_recur)
...@@ -1042,7 +1042,7 @@ elemwise.TensorConstant = TensorConstant ...@@ -1042,7 +1042,7 @@ elemwise.TensorConstant = TensorConstant
######################### #########################
def _scal_elemwise_with_nfunc(nfunc, nin, nout): def _scal_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
""" """
Replace a symbol definition with an elementwise version of the Replace a symbol definition with an elementwise version of the
corresponding scalar Op. If it is not None, the nfunc argument corresponding scalar Op. If it is not None, the nfunc argument
...@@ -1056,46 +1056,45 @@ def _scal_elemwise_with_nfunc(nfunc, nin, nout): ...@@ -1056,46 +1056,45 @@ def _scal_elemwise_with_nfunc(nfunc, nin, nout):
""" """
def construct(symbol): def construct(symbol):
symbolname = symbol.__name__ nonlocal symbolname
inplace = symbolname.endswith("_inplace")
if inplace:
msg = "inplace"
else:
msg = "no_inplace"
n = f"Elemwise{{{symbolname},{msg}}}" symbolname = symbolname or symbol.__name__
if inplace: if symbolname.endswith("_inplace"):
elemwise_name = f"Elemwise{{{symbolname},inplace}}"
scalar_op = getattr(scal, symbolname[: -len("_inplace")]) scalar_op = getattr(scal, symbolname[: -len("_inplace")])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise( rval = elemwise.Elemwise(
inplace_scalar_op, inplace_scalar_op,
{0: 0}, {0: 0},
name=n, name=elemwise_name,
nfunc_spec=(nfunc and (nfunc, nin, nout)), nfunc_spec=(nfunc and (nfunc, nin, nout)),
) )
else: else:
elemwise_name = f"Elemwise{{{symbolname},no_inplace}}"
scalar_op = getattr(scal, symbolname) scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise( rval = elemwise.Elemwise(
scalar_op, name=n, nfunc_spec=(nfunc and (nfunc, nin, nout)) scalar_op, name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout))
) )
if getattr(symbol, "__doc__", False): if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__ + "\n" + rval.__doc__ rval.__doc__ = symbol.__doc__ + "\n" + rval.__doc__
# for the meaning of this see the ./epydoc script # for the meaning of this see the ./epydoc script
# it makes epydoc display rval as if it were a function, not an object # it makes epydoc display rval as if it were a function, not an object
rval.__epydoc_asRoutine = symbol rval.__epydoc_asRoutine = symbol
rval.__module__ = "tensor" rval.__module__ = symbol.__module__
pprint.assign(rval, printing.FunctionPrinter(symbolname)) pprint.assign(
rval, printing.FunctionPrinter(symbolname.replace("_inplace", "="))
)
return rval return rval
return construct if symbol:
return construct(symbol[0])
else:
_scal_elemwise = _scal_elemwise_with_nfunc(None, None, None) return construct
def _pack(x): def _pack(x):
...@@ -1772,14 +1771,14 @@ class Max(CAReduce): ...@@ -1772,14 +1771,14 @@ class Max(CAReduce):
nfunc_spec = ("max", 1, 1) nfunc_spec = ("max", 1, 1)
def __init__(self, axis): def __init__(self, axis):
super().__init__(scal.maximum, axis) super().__init__(scal.scalar_maximum, axis)
class Min(CAReduce): class Min(CAReduce):
nfunc_spec = ("min", 1, 1) nfunc_spec = ("min", 1, 1)
def __init__(self, axis): def __init__(self, axis):
super().__init__(scal.minimum, axis) super().__init__(scal.scalar_minimum, axis)
@constructor @constructor
...@@ -3661,13 +3660,13 @@ setdefault = default # legacy ...@@ -3661,13 +3660,13 @@ setdefault = default # legacy
########################## ##########################
# Arithmetics # Arithmetics
########################## ##########################
@_scal_elemwise @_scal_elemwise(symbolname="scalar_maximum")
def maximum(x, y): def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor""" """elemwise maximum. See max for the maximum in one tensor"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise(symbolname="scalar_minimum")
def minimum(x, y): def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor""" """elemwise minimum. See min for the minimum in one tensor"""
# see decorator for function body # see decorator for function body
......
...@@ -1348,9 +1348,9 @@ class CAReduce(COp): ...@@ -1348,9 +1348,9 @@ class CAReduce(COp):
self.ufunc = np.add self.ufunc = np.add
elif isinstance(scalar_op, theano.scalar.basic.Mul): elif isinstance(scalar_op, theano.scalar.basic.Mul):
self.ufunc = np.multiply self.ufunc = np.multiply
elif isinstance(scalar_op, theano.scalar.basic.Maximum): elif isinstance(scalar_op, theano.scalar.basic.ScalarMaximum):
self.ufunc = np.maximum self.ufunc = np.maximum
elif isinstance(scalar_op, theano.scalar.basic.Minimum): elif isinstance(scalar_op, theano.scalar.basic.ScalarMinimum):
self.ufunc = np.minimum self.ufunc = np.minimum
elif isinstance(scalar_op, theano.scalar.basic.AND) and _numpy_ver >= [1, 12]: elif isinstance(scalar_op, theano.scalar.basic.AND) and _numpy_ver >= [1, 12]:
# numpy.bitwise_and.identity was incorrect for versions before # numpy.bitwise_and.identity was incorrect for versions before
...@@ -1570,8 +1570,8 @@ class CAReduce(COp): ...@@ -1570,8 +1570,8 @@ class CAReduce(COp):
if hasattr(self.scalar_op, "identity"): if hasattr(self.scalar_op, "identity"):
identity = self.scalar_op.identity identity = self.scalar_op.identity
elif self.scalar_op in [scalar.maximum, scalar.minimum]: elif self.scalar_op in [scalar.scalar_maximum, scalar.scalar_minimum]:
if self.scalar_op == scalar.maximum: if self.scalar_op == scalar.scalar_maximum:
scal_name = "maximum" scal_name = "maximum"
if input.type.dtype in ["float32", "float64"]: if input.type.dtype in ["float32", "float64"]:
identity = "-__builtin_inf()" identity = "-__builtin_inf()"
...@@ -1580,7 +1580,7 @@ class CAReduce(COp): ...@@ -1580,7 +1580,7 @@ class CAReduce(COp):
identity = "0" identity = "0"
else: else:
identity = "NPY_MIN_" + str(input.type.dtype).upper() identity = "NPY_MIN_" + str(input.type.dtype).upper()
if self.scalar_op == scalar.minimum: if self.scalar_op == scalar.scalar_minimum:
scal_name = "minimum" scal_name = "minimum"
if input.type.dtype in ["float32", "float64"]: if input.type.dtype in ["float32", "float64"]:
identity = "__builtin_inf()" identity = "__builtin_inf()"
......
from theano import printing from theano import printing
from theano import scalar as scal
from theano.printing import pprint from theano.printing import pprint
from theano.tensor import elemwise
from theano.tensor.basic import _scal_elemwise
from . import elemwise
@_scal_elemwise
def _scal_inplace(symbol):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
symbolname = symbol.__name__
inplace = symbolname.endswith("_inplace")
if inplace:
scalar_op = getattr(scal, symbolname[: -len("_inplace")])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=symbolname)
else:
scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise(scalar_op, name=symbolname)
if getattr(symbol, "__doc__", False):
rval.__doc__ = symbol.__doc__ + "\n" + rval.__doc__
# for the meaning of this see the ./epydoc script
# it makes epydoc display rval as if it were a function, not an object
rval.__epydoc_asRoutine = symbol
rval.__module__ = "theano.tensor.inplace"
pprint.assign(rval, printing.FunctionPrinter(symbolname.replace("_inplace", "=")))
return rval
@_scal_inplace
def lt_inplace(a, b): def lt_inplace(a, b):
"""a < b (inplace on a)""" """a < b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def gt_inplace(a, b): def gt_inplace(a, b):
"""a > b (inplace on a)""" """a > b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def le_inplace(a, b): def le_inplace(a, b):
"""a <= b (inplace on a)""" """a <= b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def ge_inplace(a, b): def ge_inplace(a, b):
"""a >= b (inplace on a)""" """a >= b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def eq_inplace(a, b): def eq_inplace(a, b):
"""a == b (inplace on a)""" """a == b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def neq_inplace(a, b): def neq_inplace(a, b):
"""a != b (inplace on a)""" """a != b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def and__inplace(a, b): def and__inplace(a, b):
"""bitwise a & b (inplace on a)""" """bitwise a & b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def or__inplace(a, b): def or__inplace(a, b):
"""bitwise a | b (inplace on a)""" """bitwise a | b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def xor_inplace(a, b): def xor_inplace(a, b):
"""bitwise a ^ b (inplace on a)""" """bitwise a ^ b (inplace on a)"""
@_scal_inplace @_scal_elemwise
def invert_inplace(a): def invert_inplace(a):
"""bitwise ~a (inplace on a)""" """bitwise ~a (inplace on a)"""
@_scal_inplace @_scal_elemwise
def abs__inplace(a): def abs__inplace(a):
"""|`a`| (inplace on `a`)""" """|`a`| (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def exp_inplace(a): def exp_inplace(a):
"""e^`a` (inplace on `a`)""" """e^`a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def exp2_inplace(a): def exp2_inplace(a):
"""2^`a` (inplace on `a`)""" """2^`a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def expm1_inplace(a): def expm1_inplace(a):
"""e^`a` - 1 (inplace on `a`)""" """e^`a` - 1 (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def neg_inplace(a): def neg_inplace(a):
"""-a (inplace on a)""" """-a (inplace on a)"""
@_scal_inplace @_scal_elemwise
def inv_inplace(a): def inv_inplace(a):
"""1.0/a (inplace on a)""" """1.0/a (inplace on a)"""
@_scal_inplace @_scal_elemwise
def log_inplace(a): def log_inplace(a):
"""base e logarithm of a (inplace on a)""" """base e logarithm of a (inplace on a)"""
@_scal_inplace @_scal_elemwise
def log1p_inplace(a): def log1p_inplace(a):
"""log(1+a)""" """log(1+a)"""
@_scal_inplace @_scal_elemwise
def log2_inplace(a): def log2_inplace(a):
"""base 2 logarithm of a (inplace on a)""" """base 2 logarithm of a (inplace on a)"""
@_scal_inplace @_scal_elemwise
def log10_inplace(a): def log10_inplace(a):
"""base 10 logarithm of a (inplace on a)""" """base 10 logarithm of a (inplace on a)"""
@_scal_inplace @_scal_elemwise
def sgn_inplace(a): def sgn_inplace(a):
"""sign of `a` (inplace on `a`)""" """sign of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def ceil_inplace(a): def ceil_inplace(a):
"""ceil of `a` (inplace on `a`)""" """ceil of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def floor_inplace(a): def floor_inplace(a):
"""floor of `a` (inplace on `a`)""" """floor of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def trunc_inplace(a): def trunc_inplace(a):
"""trunc of `a` (inplace on `a`)""" """trunc of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def round_half_to_even_inplace(a): def round_half_to_even_inplace(a):
"""round_half_to_even_inplace(a) (inplace on `a`)""" """round_half_to_even_inplace(a) (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def round_half_away_from_zero_inplace(a): def round_half_away_from_zero_inplace(a):
"""round_half_away_from_zero_inplace(a) (inplace on `a`)""" """round_half_away_from_zero_inplace(a) (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def sqr_inplace(a): def sqr_inplace(a):
"""square of `a` (inplace on `a`)""" """square of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def sqrt_inplace(a): def sqrt_inplace(a):
"""square root of `a` (inplace on `a`)""" """square root of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def deg2rad_inplace(a): def deg2rad_inplace(a):
"""convert degree `a` to radian(inplace on `a`)""" """convert degree `a` to radian(inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def rad2deg_inplace(a): def rad2deg_inplace(a):
"""convert radian `a` to degree(inplace on `a`)""" """convert radian `a` to degree(inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def cos_inplace(a): def cos_inplace(a):
"""cosine of `a` (inplace on `a`)""" """cosine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def arccos_inplace(a): def arccos_inplace(a):
"""arccosine of `a` (inplace on `a`)""" """arccosine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def sin_inplace(a): def sin_inplace(a):
"""sine of `a` (inplace on `a`)""" """sine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def arcsin_inplace(a): def arcsin_inplace(a):
"""arcsine of `a` (inplace on `a`)""" """arcsine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def tan_inplace(a): def tan_inplace(a):
"""tangent of `a` (inplace on `a`)""" """tangent of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def arctan_inplace(a): def arctan_inplace(a):
"""arctangent of `a` (inplace on `a`)""" """arctangent of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def arctan2_inplace(a, b): def arctan2_inplace(a, b):
"""arctangent of `a` / `b` (inplace on `a`)""" """arctangent of `a` / `b` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def cosh_inplace(a): def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)""" """hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def arccosh_inplace(a): def arccosh_inplace(a):
"""hyperbolic arc cosine of `a` (inplace on `a`)""" """hyperbolic arc cosine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def sinh_inplace(a): def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)""" """hyperbolic sine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def arcsinh_inplace(a): def arcsinh_inplace(a):
"""hyperbolic arc sine of `a` (inplace on `a`)""" """hyperbolic arc sine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def tanh_inplace(a): def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)""" """hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def arctanh_inplace(a): def arctanh_inplace(a):
"""hyperbolic arc tangent of `a` (inplace on `a`)""" """hyperbolic arc tangent of `a` (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def erf_inplace(a): def erf_inplace(a):
"""error function""" """error function"""
@_scal_inplace @_scal_elemwise
def erfc_inplace(a): def erfc_inplace(a):
"""complementary error function""" """complementary error function"""
@_scal_inplace @_scal_elemwise
def erfcx_inplace(a): def erfcx_inplace(a):
"""scaled complementary error function""" """scaled complementary error function"""
@_scal_inplace @_scal_elemwise
def gamma_inplace(a): def gamma_inplace(a):
"""gamma function""" """gamma function"""
@_scal_inplace @_scal_elemwise
def gammaln_inplace(a): def gammaln_inplace(a):
"""log gamma function""" """log gamma function"""
@_scal_inplace @_scal_elemwise
def psi_inplace(a): def psi_inplace(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
@_scal_inplace @_scal_elemwise
def tri_gamma_inplace(a): def tri_gamma_inplace(a):
"""second derivative of the log gamma function""" """second derivative of the log gamma function"""
@_scal_inplace @_scal_elemwise
def chi2sf_inplace(x, k): def chi2sf_inplace(x, k):
"""chi squared survival function""" """chi squared survival function"""
@_scal_inplace @_scal_elemwise
def j0_inplace(x): def j0_inplace(x):
"""Bessel function of the first kind of order 0.""" """Bessel function of the first kind of order 0."""
@_scal_inplace @_scal_elemwise
def j1_inplace(x): def j1_inplace(x):
"""Bessel function of the first kind of order 1.""" """Bessel function of the first kind of order 1."""
@_scal_inplace @_scal_elemwise
def jv_inplace(v, x): def jv_inplace(v, x):
"""Bessel function of the first kind of order v (real).""" """Bessel function of the first kind of order v (real)."""
@_scal_inplace @_scal_elemwise
def i0_inplace(x): def i0_inplace(x):
"""Modified Bessel function of the first kind of order 0.""" """Modified Bessel function of the first kind of order 0."""
@_scal_inplace @_scal_elemwise
def i1_inplace(x): def i1_inplace(x):
"""Modified Bessel function of the first kind of order 1.""" """Modified Bessel function of the first kind of order 1."""
@_scal_inplace @_scal_elemwise
def iv_inplace(v, x): def iv_inplace(v, x):
"""Modified Bessel function of the first kind of order v (real).""" """Modified Bessel function of the first kind of order v (real)."""
@_scal_inplace @_scal_elemwise
def second_inplace(a): def second_inplace(a):
"""Fill `a` with `b`""" """Fill `a` with `b`"""
...@@ -324,52 +298,52 @@ fill_inplace = second_inplace ...@@ -324,52 +298,52 @@ fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter("fill=")) pprint.assign(fill_inplace, printing.FunctionPrinter("fill="))
@_scal_inplace @_scal_elemwise(symbolname="scalar_maximum_inplace")
def maximum_inplace(a, b): def maximum_inplace(a, b):
"""elementwise addition (inplace on `a`)""" """elementwise addition (inplace on `a`)"""
@_scal_inplace @_scal_elemwise(symbolname="scalar_minimum_inplace")
def minimum_inplace(a, b): def minimum_inplace(a, b):
"""elementwise addition (inplace on `a`)""" """elementwise addition (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def add_inplace(a, b): def add_inplace(a, b):
"""elementwise addition (inplace on `a`)""" """elementwise addition (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def sub_inplace(a, b): def sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)""" """elementwise subtraction (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def mul_inplace(a, b): def mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)""" """elementwise multiplication (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def true_div_inplace(a, b): def true_div_inplace(a, b):
"""elementwise division (inplace on `a`)""" """elementwise division (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def int_div_inplace(a, b): def int_div_inplace(a, b):
"""elementwise division (inplace on `a`)""" """elementwise division (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def mod_inplace(a, b): def mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)""" """elementwise modulo (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def pow_inplace(a, b): def pow_inplace(a, b):
"""elementwise power (inplace on `a`)""" """elementwise power (inplace on `a`)"""
@_scal_inplace @_scal_elemwise
def conj_inplace(a): def conj_inplace(a):
"""elementwise conjugate (inplace on `a`)""" """elementwise conjugate (inplace on `a`)"""
......
...@@ -5624,7 +5624,7 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -5624,7 +5624,7 @@ def local_useless_elemwise_comparison(fgraph, node):
return [res] return [res]
# Elemwise[{minimum,maximum}](X, X) -> X # Elemwise[{minimum,maximum}](X, X) -> X
if ( if (
isinstance(node.op.scalar_op, (ts.Minimum, ts.Maximum)) isinstance(node.op.scalar_op, (ts.ScalarMinimum, ts.ScalarMaximum))
and node.inputs[0] is node.inputs[1] and node.inputs[0] is node.inputs[1]
): ):
res = node.inputs[0] res = node.inputs[0]
...@@ -5656,7 +5656,7 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -5656,7 +5656,7 @@ def local_useless_elemwise_comparison(fgraph, node):
return [res] return [res]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if ( if (
isinstance(node.op.scalar_op, ts.Maximum) isinstance(node.op.scalar_op, ts.ScalarMaximum)
and node.inputs[0].owner and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i) and isinstance(node.inputs[0].owner.op, Shape_i)
and tt.extract_constant(node.inputs[1], only_process_constants=True) == 0 and tt.extract_constant(node.inputs[1], only_process_constants=True) == 0
...@@ -5665,7 +5665,7 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -5665,7 +5665,7 @@ def local_useless_elemwise_comparison(fgraph, node):
return [node.inputs[0]] return [node.inputs[0]]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i] # Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
if ( if (
isinstance(node.op.scalar_op, ts.Maximum) isinstance(node.op.scalar_op, ts.ScalarMaximum)
and tt.extract_constant(node.inputs[0], only_process_constants=True) == 0 and tt.extract_constant(node.inputs[0], only_process_constants=True) == 0
and node.inputs[1].owner and node.inputs[1].owner
and isinstance(node.inputs[1].owner.op, Shape_i) and isinstance(node.inputs[1].owner.op, Shape_i)
...@@ -5674,7 +5674,7 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -5674,7 +5674,7 @@ def local_useless_elemwise_comparison(fgraph, node):
return [node.inputs[1]] return [node.inputs[1]]
# Elemwise[minimum](X.shape[i], 0) -> 0 # Elemwise[minimum](X.shape[i], 0) -> 0
if ( if (
isinstance(node.op.scalar_op, ts.Minimum) isinstance(node.op.scalar_op, ts.ScalarMinimum)
and node.inputs[0].owner and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i) and isinstance(node.inputs[0].owner.op, Shape_i)
and tt.extract_constant(node.inputs[1], only_process_constants=True) == 0 and tt.extract_constant(node.inputs[1], only_process_constants=True) == 0
...@@ -5686,7 +5686,7 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -5686,7 +5686,7 @@ def local_useless_elemwise_comparison(fgraph, node):
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](0, X.shape[i]) -> 0
if ( if (
isinstance(node.op.scalar_op, ts.Minimum) isinstance(node.op.scalar_op, ts.ScalarMinimum)
and tt.extract_constant(node.inputs[0], only_process_constants=True) == 0 and tt.extract_constant(node.inputs[0], only_process_constants=True) == 0
and node.inputs[1].owner and node.inputs[1].owner
and isinstance(node.inputs[1].owner.op, Shape_i) and isinstance(node.inputs[1].owner.op, Shape_i)
...@@ -6039,7 +6039,7 @@ def local_reduce_join(fgraph, node): ...@@ -6039,7 +6039,7 @@ def local_reduce_join(fgraph, node):
if tt.extract_constant(join.inputs[0], only_process_constants=True) != 0: if tt.extract_constant(join.inputs[0], only_process_constants=True) != 0:
return return
if isinstance(node.op.scalar_op, (ts.Maximum, ts.Minimum)): if isinstance(node.op.scalar_op, (ts.ScalarMaximum, ts.ScalarMinimum)):
# Support only 2 inputs for now # Support only 2 inputs for now
if len(join.inputs) != 3: if len(join.inputs) != 3:
return return
......
...@@ -82,7 +82,7 @@ def local_max_to_min(fgraph, node): ...@@ -82,7 +82,7 @@ def local_max_to_min(fgraph, node):
if ( if (
max.owner max.owner
and isinstance(max.owner.op, CAReduce) and isinstance(max.owner.op, CAReduce)
and max.owner.op.scalar_op == scal.maximum and max.owner.op.scalar_op == scal.scalar_maximum
): ):
neg = max.owner.inputs[0] neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == tt.neg: if neg.owner and neg.owner.op == tt.neg:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论