Commit 80d8a0ef authored by Brandon T. Willard

Attach NumPy information to Elemwise operations

Parent a62dec23
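
The commit attaches an nfunc_spec attribute to several reduction Ops. As a rough sketch (the consumer of this attribute is not shown in the diff), the tuple follows Theano's (numpy_function_name, n_inputs, n_outputs) convention, so the matching NumPy callable can be resolved by name:

import numpy as np

# Hedged sketch of the nfunc_spec convention assumed by this commit:
# (numpy_function_name, n_inputs, n_outputs).
nfunc_spec = ("max", 1, 1)
name, n_in, n_out = nfunc_spec
nfunc = getattr(np, name)  # resolves to np.max
assert nfunc(np.array([[1, 5], [3, 2]])) == 5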
@@ -1787,6 +1787,20 @@ def max_and_argmax(a, axis=None, keepdims=False):
     return [out, argout]


+class Max(CAReduce):
+    nfunc_spec = ("max", 1, 1)
+
+    def __init__(self, axis):
+        super().__init__(scal.maximum, axis)
+
+
+class Min(CAReduce):
+    nfunc_spec = ("min", 1, 1)
+
+    def __init__(self, axis):
+        super().__init__(scal.minimum, axis)
+
+
 @constructor
 def max(x, axis=None, keepdims=False):
     """
@@ -1823,7 +1837,7 @@ def max(x, axis=None, keepdims=False):
     try:
         out = max_and_argmax(x, axis)[0]
     except Exception:
-        out = CAReduce(scal.maximum, axis)(x)
+        out = Max(axis)(x)

     if keepdims:
         out = makeKeepDims(x, out, axis)
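
A small usage sketch (assuming Max is re-exported via theano.tensor, as the optimizer module below expects): the class attribute is visible on instances, so graph rewrites can consult it.

import theano.tensor as tt

assert tt.Max(axis=0).nfunc_spec == ("max", 1, 1)
x = tt.matrix("x")
y = tt.max(x, axis=0)  # prefers max_and_argmax; Max(axis) is the fallback path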
@@ -3416,7 +3430,7 @@ def prod(

 class Mean(elemwise.CAReduce):
     def __init__(self, axis=None):
-        elemwise.CAReduce.__init__(self, scal.add, axis)
+        super().__init__(scal.add, axis)
         assert self.axis is None or len(self.axis) == 1

     def __str__(self):
@@ -3443,7 +3457,7 @@ class Mean(elemwise.CAReduce):
     def c_code(self, node, name, inames, onames, sub):
         if self.axis is not None:
             return super(Op, self).c_code(node, name, inames, onames, sub)
-        ret = elemwise.CAReduce.c_code(self, node, name, inames, onames, sub)
+        ret = super().c_code(node, name, inames, onames, sub)
         # TODO: c_code/perform support only axis=None
         return (
             ret
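
One subtlety in the Mean.c_code change: a bound super() proxy already supplies self, so it must not be passed again (a call like super().c_code(self, node, ...) would shift every argument by one). A minimal standalone illustration:

class Base:
    def c_code(self, node, name):
        return f"{name}_code"


class Child(Base):
    def c_code(self, node, name):
        # Correct: super() is already bound, so ``self`` is implicit.
        return super().c_code(node, name)


assert Child().c_code(None, "mean") == "mean_code"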
@@ -1761,6 +1761,7 @@ class All(CAReduce):
     """

     __props__ = ("axis",)
+    nfunc_spec = ("all", 1, 1)

     def __init__(self, axis=None):
         CAReduce.__init__(self, scalar.and_, axis)
@@ -1793,6 +1794,7 @@ class Any(CAReduce):
     """

     __props__ = ("axis",)
+    nfunc_spec = ("any", 1, 1)

     def __init__(self, axis=None):
         CAReduce.__init__(self, scalar.or_, axis)
@@ -2027,6 +2029,7 @@ class Sum(CAReduceDtype):
     """

     __props__ = ("axis", "dtype", "acc_dtype")
+    nfunc_spec = ("sum", 1, 1)

     def __init__(self, axis=None, dtype=None, acc_dtype=None):
         CAReduceDtype.__init__(
@@ -2085,6 +2088,7 @@ class Prod(CAReduceDtype):
     """

     __props__ = ("axis", "dtype", "acc_dtype")
+    nfunc_spec = ("prod", 1, 1)

     def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
         CAReduceDtype.__init__(
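
Why Prod must map to "prod" rather than "sum": the name in nfunc_spec selects the NumPy reduction that is expected to compute the same result as the op.

import numpy as np

# The two reductions disagree on anything but trivial inputs.
assert getattr(np, "prod")(np.array([2, 3])) == 6
assert getattr(np, "sum")(np.array([2, 3])) == 5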
@@ -31,44 +31,40 @@ supposed to be canonical.

 """

-# TODO: intelligent merge for mul/add
-# TODO: 0*x -> 0
 import logging

-from theano import gof
-from theano.tensor.elemwise import CAReduce
-from theano.tensor import basic as T
-from theano.tensor import DimShuffle, Subtensor
+import theano.tensor.basic as tt
+import theano.scalar.basic as scal
+from theano.gof.opt import copy_stack_trace, local_optimizer
+from theano.tensor.subtensor import Subtensor
+from theano.tensor.elemwise import CAReduce, DimShuffle
 from theano.tensor.opt import register_uncanonicalize
-from theano import scalar as scal
-from theano.gof.opt import copy_stack_trace

 _logger = logging.getLogger("theano.tensor.opt")


 @register_uncanonicalize
-@gof.local_optimizer([T.MaxAndArgmax])
+@local_optimizer([tt.MaxAndArgmax])
 def local_max_and_argmax(node):
     """
     If we don't use the argmax, change it to a max only.

     """
-    if isinstance(node.op, T.MaxAndArgmax):
+    if isinstance(node.op, tt.MaxAndArgmax):
         axis = node.op.get_params(node)
         if len(node.outputs[1].clients) == 0:
-            new = CAReduce(scal.maximum, axis)(node.inputs[0])
+            new = tt.Max(axis)(node.inputs[0])
             copy_stack_trace(node.outputs[0], new)
             return [new, None]

         if len(node.outputs[0].clients) == 0:
-            new = T.Argmax(axis)(node.inputs[0])
+            new = tt.Argmax(axis)(node.inputs[0])
             copy_stack_trace(node.outputs[0], new)
             return [None, new]
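
A sketch of when this rewrite fires (assuming the standard theano.function pipeline runs the uncanonicalize pass): building only the max output leaves the argmax client-less, so the two-output MaxAndArgmax can be replaced by the single-output Max.

import theano
import theano.tensor as tt

x = tt.matrix("x")
m = tt.max(x, axis=0)        # graph contains MaxAndArgmax; the argmax is unused
f = theano.function([x], m)  # the optimizer may substitute tt.Max(axis) here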
 @register_uncanonicalize
-@gof.local_optimizer([T.neg])
+@local_optimizer([tt.neg])
 def local_max_to_min(node):
     """
     Change -(max(-x)) to min.
@@ -81,7 +77,7 @@ def local_max_to_min(node):
     the interface puts only MaxAndArgmax into the graph.

     """
-    if node.op == T.neg and node.inputs[0].owner:
+    if node.op == tt.neg and node.inputs[0].owner:
         max = node.inputs[0]
         if (
             max.owner
@@ -89,15 +85,15 @@ def local_max_to_min(node):
             and max.owner.op.scalar_op == scal.maximum
         ):
             neg = max.owner.inputs[0]
-            if neg.owner and neg.owner.op == T.neg:
-                new = CAReduce(scal.minimum, max.owner.op.axis)(neg.owner.inputs[0])
+            if neg.owner and neg.owner.op == tt.neg:
+                new = tt.Min(max.owner.op.axis)(neg.owner.inputs[0])
                 return [copy_stack_trace(node.outputs[0], new)]

     return False
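
The identity this optimizer relies on: negation reverses the ordering, so -(max(-x)) recovers the minimum. A quick NumPy check:

import numpy as np

x = np.array([3.0, -1.0, 7.0])
assert -np.max(-x) == np.min(x)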
 @register_uncanonicalize
-@gof.local_optimizer([T.Alloc])
+@local_optimizer([tt.Alloc])
 def local_alloc_dimshuffle(node):
     """
     If a dimshuffle is inside an alloc and only adds dimensions to the
@@ -105,7 +101,7 @@ def local_alloc_dimshuffle(node):
     Alloc(DimShuffle(x), ...) -> Alloc(x, ...)

     """
-    if isinstance(node.op, T.Alloc):
+    if isinstance(node.op, tt.Alloc):
         input_ = node.inputs[0]
         if input_.owner and isinstance(input_.owner.op, DimShuffle):
             # check if it only adds dimensions to the left
@@ -115,12 +111,12 @@ def local_alloc_dimshuffle(node):
             ) + tuple(range(input_.owner.inputs[0].ndim))
             if new_order != expected_new_order:
                 return False
-            return [T.alloc(input_.owner.inputs[0], *node.inputs[1:])]
+            return [tt.alloc(input_.owner.inputs[0], *node.inputs[1:])]

     return False
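
A NumPy analogue of Alloc(DimShuffle(x), ...) -> Alloc(x, ...), using np.broadcast_to as a stand-in for Alloc: broadcasting already handles the implicit leading axis, so the explicit DimShuffle is redundant.

import numpy as np

x = np.arange(4)
a = np.broadcast_to(x[None, :], (3, 4))  # DimShuffle('x', 0), then Alloc
b = np.broadcast_to(x, (3, 4))           # Alloc directly
assert (a == b).all()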
 @register_uncanonicalize
-@gof.local_optimizer([T.Reshape])
+@local_optimizer([tt.Reshape])
 def local_reshape_dimshuffle(node):
     """
     If a dimshuffle is inside a reshape and does not change the order
@@ -128,7 +124,7 @@ def local_reshape_dimshuffle(node):
     Reshape(Dimshuffle(x), shp) -> Reshape(x, shp)

     """
-    if isinstance(node.op, T.Reshape):
+    if isinstance(node.op, tt.Reshape):
         input_ = node.inputs[0]
         if input_.owner and isinstance(input_.owner.op, DimShuffle):
             new_order = input_.owner.op.new_order
@@ -141,7 +137,7 @@ def local_reshape_dimshuffle(node):
                 else:
                     offset += 1
             return [
-                T.reshape(
+                tt.reshape(
                     input_.owner.inputs[0], node.inputs[1], ndim=node.outputs[0].ndim
                 )
             ]
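
Similarly for Reshape(Dimshuffle(x), shp) -> Reshape(x, shp): inserting broadcastable axes does not change the C-order element layout, so the reshape result is identical (NumPy stand-in):

import numpy as np

x = np.arange(6).reshape(2, 3)
assert (x[:, None, :].reshape(3, 2) == x.reshape(3, 2)).all()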
@@ -149,7 +145,7 @@ def local_reshape_dimshuffle(node):


 @register_uncanonicalize
-@gof.local_optimizer([DimShuffle])
+@local_optimizer([DimShuffle])
 def local_dimshuffle_alloc(node):
     """
     If an alloc is inside a dimshuffle which only adds dimensions to the left,
@@ -159,7 +155,7 @@ def local_dimshuffle_alloc(node):

     """
     if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
         input_ = node.inputs[0]
-        if isinstance(input_.owner.op, T.Alloc):
+        if isinstance(input_.owner.op, tt.Alloc):
             # check if it only adds dimensions to the left
             new_order = node.op.new_order
             expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple(
@@ -172,12 +168,12 @@ def local_dimshuffle_alloc(node):

             nb_new_dims = len(new_order) - input_.ndim
             new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:])
-            return [T.alloc(input_.owner.inputs[0], *new_shape_input)]
+            return [tt.alloc(input_.owner.inputs[0], *new_shape_input)]

     return False
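
And for DimShuffle(Alloc(x, ...)) with left-padding only: padding the target shape with 1s reproduces the dimshuffled result directly (again with np.broadcast_to standing in for Alloc):

import numpy as np

x = np.float64(7.0)
a = np.broadcast_to(x, (4,))[None, :]  # Alloc, then DimShuffle('x', 0)
b = np.broadcast_to(x, (1, 4))         # Alloc with the padded shape
assert (a == b).all()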
 @register_uncanonicalize
-@gof.local_optimizer([DimShuffle])
+@local_optimizer([DimShuffle])
 def local_dimshuffle_subtensor(node):
     """If a subtensor is inside a dimshuffle which only drops
     broadcastable dimensions, scrap the dimshuffle and index the
@@ -223,7 +219,7 @@ def local_dimshuffle_subtensor(node):
                 # tensor was indexed such as x[scalar, :, :], check that as well
                 new_idx_list = list(input_.owner.op.idx_list)
                 new_inputs = [input_.owner.inputs[0]]
-                zero = T.constant(0)
+                zero = tt.constant(0)
                 slice_attr_list = ["start", "stop", "step"]
                 j = 0
                 slice_i = -1