提交 80d8a0ef authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Attach NumPy information to Elemwise operations

上级 a62dec23
......@@ -1787,6 +1787,20 @@ def max_and_argmax(a, axis=None, keepdims=False):
return [out, argout]
class Max(CAReduce):
nfunc_spec = ("max", 1, 1)
def __init__(self, axis):
super().__init__(scal.maximum, axis)
class Min(CAReduce):
nfunc_spec = ("min", 1, 1)
def __init__(self, axis):
super().__init__(scal.minimum, axis)
@constructor
def max(x, axis=None, keepdims=False):
"""
......@@ -1823,7 +1837,7 @@ def max(x, axis=None, keepdims=False):
try:
out = max_and_argmax(x, axis)[0]
except Exception:
out = CAReduce(scal.maximum, axis)(x)
out = Max(axis)(x)
if keepdims:
out = makeKeepDims(x, out, axis)
......@@ -3416,7 +3430,7 @@ def prod(
class Mean(elemwise.CAReduce):
def __init__(self, axis=None):
elemwise.CAReduce.__init__(self, scal.add, axis)
super().__init__(scal.add, axis)
assert self.axis is None or len(self.axis) == 1
def __str__(self):
......@@ -3443,7 +3457,7 @@ class Mean(elemwise.CAReduce):
def c_code(self, node, name, inames, onames, sub):
if self.axis is not None:
return super(Op, self).c_code(node, name, inames, onames, sub)
ret = elemwise.CAReduce.c_code(self, node, name, inames, onames, sub)
ret = super().c_code(self, node, name, inames, onames, sub)
# TODO: c_code perform support only axis is None
return (
ret
......
......@@ -1761,6 +1761,7 @@ class All(CAReduce):
"""
__props__ = ("axis",)
nfunc_spec = ("all", 1, 1)
def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis)
......@@ -1793,6 +1794,7 @@ class Any(CAReduce):
"""
__props__ = ("axis",)
nfunc_spec = ("any", 1, 1)
def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis)
......@@ -2027,6 +2029,7 @@ class Sum(CAReduceDtype):
"""
__props__ = ("axis", "dtype", "acc_dtype")
nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(
......@@ -2085,6 +2088,7 @@ class Prod(CAReduceDtype):
"""
__props__ = ("axis", "dtype", "acc_dtype")
nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
CAReduceDtype.__init__(
......
......@@ -31,44 +31,40 @@ supposed to be canonical.
"""
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import logging
from theano import gof
from theano.tensor.elemwise import CAReduce
from theano.tensor import basic as T
from theano.tensor import DimShuffle, Subtensor
import theano.tensor.basic as tt
import theano.scalar.basic as scal
from theano.gof.opt import copy_stack_trace, local_optimizer
from theano.tensor.subtensor import Subtensor
from theano.tensor.elemwise import CAReduce, DimShuffle
from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal
from theano.gof.opt import copy_stack_trace
_logger = logging.getLogger("theano.tensor.opt")
@register_uncanonicalize
@gof.local_optimizer([T.MaxAndArgmax])
@local_optimizer([tt.MaxAndArgmax])
def local_max_and_argmax(node):
"""
If we don't use the argmax, change it to a max only.
"""
if isinstance(node.op, T.MaxAndArgmax):
if isinstance(node.op, tt.MaxAndArgmax):
axis = node.op.get_params(node)
if len(node.outputs[1].clients) == 0:
new = CAReduce(scal.maximum, axis)(node.inputs[0])
new = tt.Max(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [new, None]
if len(node.outputs[0].clients) == 0:
new = T.Argmax(axis)(node.inputs[0])
new = tt.Argmax(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [None, new]
@register_uncanonicalize
@gof.local_optimizer([T.neg])
@local_optimizer([tt.neg])
def local_max_to_min(node):
"""
Change -(max(-x)) to min.
......@@ -81,7 +77,7 @@ def local_max_to_min(node):
the interface put only MaxAndArgmax into the graph.
"""
if node.op == T.neg and node.inputs[0].owner:
if node.op == tt.neg and node.inputs[0].owner:
max = node.inputs[0]
if (
max.owner
......@@ -89,15 +85,15 @@ def local_max_to_min(node):
and max.owner.op.scalar_op == scal.maximum
):
neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg:
new = CAReduce(scal.minimum, max.owner.op.axis)(neg.owner.inputs[0])
if neg.owner and neg.owner.op == tt.neg:
new = tt.Min(max.owner.op.axis)(neg.owner.inputs[0])
return [copy_stack_trace(node.outputs[0], new)]
return False
@register_uncanonicalize
@gof.local_optimizer([T.Alloc])
@local_optimizer([tt.Alloc])
def local_alloc_dimshuffle(node):
"""
If a dimshuffle is inside an alloc and only adds dimension to the
......@@ -105,7 +101,7 @@ def local_alloc_dimshuffle(node):
Alloc(DimShuffle(x), ...) - > Alloc(x, ...)
"""
if isinstance(node.op, T.Alloc):
if isinstance(node.op, tt.Alloc):
input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left
......@@ -115,12 +111,12 @@ def local_alloc_dimshuffle(node):
) + tuple(range(input_.owner.inputs[0].ndim))
if new_order != expected_new_order:
return False
return [T.alloc(input_.owner.inputs[0], *node.inputs[1:])]
return [tt.alloc(input_.owner.inputs[0], *node.inputs[1:])]
return False
@register_uncanonicalize
@gof.local_optimizer([T.Reshape])
@local_optimizer([tt.Reshape])
def local_reshape_dimshuffle(node):
"""
If a dimshuffle is inside a reshape and does not change the order
......@@ -128,7 +124,7 @@ def local_reshape_dimshuffle(node):
Reshape(Dimshuffle(x), shp) -> Reshape(x, shp)
"""
if isinstance(node.op, T.Reshape):
if isinstance(node.op, tt.Reshape):
input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_.owner.op.new_order
......@@ -141,7 +137,7 @@ def local_reshape_dimshuffle(node):
else:
offset += 1
return [
T.reshape(
tt.reshape(
input_.owner.inputs[0], node.inputs[1], ndim=node.outputs[0].ndim
)
]
......@@ -149,7 +145,7 @@ def local_reshape_dimshuffle(node):
@register_uncanonicalize
@gof.local_optimizer([DimShuffle])
@local_optimizer([DimShuffle])
def local_dimshuffle_alloc(node):
"""
If an alloc is inside a dimshuffle which only adds dimension to the left,
......@@ -159,7 +155,7 @@ def local_dimshuffle_alloc(node):
"""
if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
input_ = node.inputs[0]
if isinstance(input_.owner.op, T.Alloc):
if isinstance(input_.owner.op, tt.Alloc):
# check if it only adds dimension to the left
new_order = node.op.new_order
expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple(
......@@ -172,12 +168,12 @@ def local_dimshuffle_alloc(node):
nb_new_dims = len(new_order) - input_.ndim
new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:])
return [T.alloc(input_.owner.inputs[0], *new_shape_input)]
return [tt.alloc(input_.owner.inputs[0], *new_shape_input)]
return False
@register_uncanonicalize
@gof.local_optimizer([DimShuffle])
@local_optimizer([DimShuffle])
def local_dimshuffle_subtensor(node):
"""If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the
......@@ -223,7 +219,7 @@ def local_dimshuffle_subtensor(node):
# tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list)
new_inputs = [input_.owner.inputs[0]]
zero = T.constant(0)
zero = tt.constant(0)
slice_attr_list = ["start", "stop", "step"]
j = 0
slice_i = -1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论