提交 94c2e4c2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace use of broadcastable with shape in aesara.tensor.elemwise

上级 f1dc0897
...@@ -62,17 +62,17 @@ class DimShuffle(ExternalCOp): ...@@ -62,17 +62,17 @@ class DimShuffle(ExternalCOp):
If `j = new_order[i]` is an index, the output's ith dimension If `j = new_order[i]` is an index, the output's ith dimension
will be the input's jth dimension. will be the input's jth dimension.
If `new_order[i]` is `x`, the output's ith dimension will If `new_order[i]` is `x`, the output's ith dimension will
be 1 and Broadcast operations will be allowed to do broadcasting be 1 and broadcast operations will be allowed to do broadcasting
over that dimension. over that dimension.
If `input.broadcastable[i] == False` then `i` must be found in new_order. If `input.type.shape[i] != 1` then `i` must be found in `new_order`.
Broadcastable dimensions, on the other hand, can be discarded. Broadcastable dimensions, on the other hand, can be discarded.
.. code-block:: python .. code-block:: python
DimShuffle((False, False, False), ['x', 2, 'x', 0, 1]) DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])
This op will only work on 3d tensors with no broadcastable This `Op` will only work on 3d tensors with no broadcastable
dimensions. The first dimension will be broadcastable, dimensions. The first dimension will be broadcastable,
then we will have the third dimension of the input tensor as then we will have the third dimension of the input tensor as
the second of the resulting tensor, etc. If the tensor has the second of the resulting tensor, etc. If the tensor has
...@@ -83,7 +83,7 @@ class DimShuffle(ExternalCOp): ...@@ -83,7 +83,7 @@ class DimShuffle(ExternalCOp):
DimShuffle((True, False), [1]) DimShuffle((True, False), [1])
This op will only work on 2d tensors with the first dimension This `Op` will only work on 2d tensors with the first dimension
broadcastable. broadcastable.
The second dimension of the input tensor will be the first dimension of The second dimension of the input tensor will be the first dimension of
the resulting tensor. the resulting tensor.
...@@ -186,7 +186,7 @@ class DimShuffle(ExternalCOp): ...@@ -186,7 +186,7 @@ class DimShuffle(ExternalCOp):
def make_node(self, _input): def make_node(self, _input):
input = as_tensor_variable(_input) input = as_tensor_variable(_input)
ib = tuple(input.type.broadcastable) ib = tuple(s == 1 for s in input.type.shape)
if ib != self.input_broadcastable: if ib != self.input_broadcastable:
if len(ib) != len(self.input_broadcastable): if len(ib) != len(self.input_broadcastable):
raise TypeError( raise TypeError(
...@@ -258,7 +258,7 @@ class DimShuffle(ExternalCOp): ...@@ -258,7 +258,7 @@ class DimShuffle(ExternalCOp):
(x,) = inp (x,) = inp
(gz,) = grads (gz,) = grads
gz = as_tensor_variable(gz) gz = as_tensor_variable(gz)
grad_order = ["x"] * len(x.type.broadcastable) grad_order = ["x"] * x.type.ndim
for i, v in enumerate(self.new_order): for i, v in enumerate(self.new_order):
if v != "x": if v != "x":
grad_order[v] = i grad_order[v] = i
...@@ -269,7 +269,7 @@ class DimShuffle(ExternalCOp): ...@@ -269,7 +269,7 @@ class DimShuffle(ExternalCOp):
return [inp[0].zeros_like(dtype=config.floatX)] return [inp[0].zeros_like(dtype=config.floatX)]
else: else:
return [ return [
DimShuffle(gz.type.broadcastable, grad_order)( DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
Elemwise(scalar_identity)(gz) Elemwise(scalar_identity)(gz)
) )
] ]
...@@ -406,7 +406,7 @@ class Elemwise(OpenMPOp): ...@@ -406,7 +406,7 @@ class Elemwise(OpenMPOp):
# TODO: use LComplete instead # TODO: use LComplete instead
args.append( args.append(
dim_shuffle( dim_shuffle(
input.type.broadcastable, tuple(1 if s == 1 else None for s in input.type.shape),
["x"] * difference + list(range(length)), ["x"] * difference + list(range(length)),
)(input) )(input)
) )
...@@ -452,11 +452,11 @@ class Elemwise(OpenMPOp): ...@@ -452,11 +452,11 @@ class Elemwise(OpenMPOp):
inplace_pattern = self.inplace_pattern inplace_pattern = self.inplace_pattern
if inplace_pattern: if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items(): for overwriter, overwritten in inplace_pattern.items():
for ob, ib in zip( for out_s, in_s in zip(
out_shapes[overwriter], out_shapes[overwriter],
inputs[overwritten].type.broadcastable, inputs[overwritten].type.shape,
): ):
if ib and not ob == 1: if in_s == 1 and out_s != 1:
raise ValueError( raise ValueError(
"Operation cannot be done inplace on an input " "Operation cannot be done inplace on an input "
"with broadcasted dimensions." "with broadcasted dimensions."
...@@ -578,8 +578,8 @@ class Elemwise(OpenMPOp): ...@@ -578,8 +578,8 @@ class Elemwise(OpenMPOp):
# TODO: only count dimensions that were effectively broadcasted # TODO: only count dimensions that were effectively broadcasted
to_sum = [ to_sum = [
j j
for j, bcast in enumerate(ipt.type.broadcastable) for j, in_s in enumerate(ipt.type.shape)
if bcast and not outs[0].broadcastable[j] if in_s == 1 and outs[0].type.shape[j] != 1
] ]
if to_sum: if to_sum:
...@@ -614,7 +614,7 @@ class Elemwise(OpenMPOp): ...@@ -614,7 +614,7 @@ class Elemwise(OpenMPOp):
f"{str(self.scalar_op)}.grad returned {str(type(scalar_igrads))} instead of list or tuple" f"{str(self.scalar_op)}.grad returned {str(type(scalar_igrads))} instead of list or tuple"
) )
nd = len(inputs[0].type.broadcastable) # this is the same for everyone nd = inputs[0].type.ndim # this is the same for everyone
def transform(r): def transform(r):
# From a graph of ScalarOps, make a graph of Broadcast ops. # From a graph of ScalarOps, make a graph of Broadcast ops.
...@@ -897,7 +897,7 @@ class Elemwise(OpenMPOp): ...@@ -897,7 +897,7 @@ class Elemwise(OpenMPOp):
# for each input: # for each input:
# same as range(ndim), but with 'x' at all broadcastable positions # same as range(ndim), but with 'x' at all broadcastable positions
orders = [ orders = [
[x and "x" or i for i, x in enumerate(input.type.broadcastable)] [s == 1 and "x" or i for i, s in enumerate(input.type.shape)]
for input in inputs for input in inputs
] ]
...@@ -920,7 +920,7 @@ class Elemwise(OpenMPOp): ...@@ -920,7 +920,7 @@ class Elemwise(OpenMPOp):
[ [
f"PyArray_ISFORTRAN({arr})" f"PyArray_ISFORTRAN({arr})"
for arr, var in z for arr, var in z
if not all(var.broadcastable) if not all(s == 1 for s in var.type.shape)
] ]
) )
# If it is a scalar, make it c contig to prevent problem with # If it is a scalar, make it c contig to prevent problem with
...@@ -1005,7 +1005,7 @@ class Elemwise(OpenMPOp): ...@@ -1005,7 +1005,7 @@ class Elemwise(OpenMPOp):
or or
# Use simpler code when output ndim == 0 or 1 # Use simpler code when output ndim == 0 or 1
# or for broadcasted scalar. # or for broadcasted scalar.
all(node.outputs[0].broadcastable) all(s == 1 for s in node.outputs[0].type.shape)
): ):
if nnested: if nnested:
all_code = [("", "")] * (nnested - 1) + [("", code)] + [""] all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
...@@ -1077,7 +1077,7 @@ class Elemwise(OpenMPOp): ...@@ -1077,7 +1077,7 @@ class Elemwise(OpenMPOp):
all(o.ndim >= 1 for o in node.outputs) all(o.ndim >= 1 for o in node.outputs)
and and
# Don't use the contig code for broadcasted scalar. # Don't use the contig code for broadcasted scalar.
not all(node.outputs[0].broadcastable) not all(s == 1 for s in node.outputs[0].type.shape)
): ):
contig = None contig = None
try: try:
...@@ -1110,7 +1110,7 @@ class Elemwise(OpenMPOp): ...@@ -1110,7 +1110,7 @@ class Elemwise(OpenMPOp):
""" """
index = "" index = ""
for x, var in zip(inames + onames, inputs + node.outputs): for x, var in zip(inames + onames, inputs + node.outputs):
if not all(var.broadcastable): if not all(s == 1 for s in var.type.shape):
contig += ( contig += (
""" """
dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s); dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
...@@ -1144,18 +1144,19 @@ class Elemwise(OpenMPOp): ...@@ -1144,18 +1144,19 @@ class Elemwise(OpenMPOp):
) )
if contig is not None: if contig is not None:
z = list(zip(inames + onames, inputs + node.outputs)) z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape)
cond1 = " && ".join( cond1 = " && ".join(
[ [
"PyArray_ISCONTIGUOUS(%s)" % arr "PyArray_ISCONTIGUOUS(%s)" % arr
for arr, var in z for arr, var in z
if not all(var.broadcastable) if not all_broadcastable
] ]
) )
cond2 = " && ".join( cond2 = " && ".join(
[ [
"PyArray_ISFORTRAN(%s)" % arr "PyArray_ISFORTRAN(%s)" % arr
for arr, var in z for arr, var in z
if not all(var.broadcastable) if not all_broadcastable
] ]
) )
loop = ( loop = (
...@@ -1388,13 +1389,7 @@ class CAReduce(COp): ...@@ -1388,13 +1389,7 @@ class CAReduce(COp):
axis = self.axis axis = self.axis
if axis is None: if axis is None:
return ((),) return ((),)
return ( return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)
[
ishape[i]
for (i, b) in enumerate(node.inputs[0].type.broadcastable)
if i not in axis
],
)
def _c_all(self, node, name, inames, onames, sub): def _c_all(self, node, name, inames, onames, sub):
...@@ -1419,7 +1414,7 @@ class CAReduce(COp): ...@@ -1419,7 +1414,7 @@ class CAReduce(COp):
axis = self.axis axis = self.axis
if axis is None: if axis is None:
axis = list(range(len(input.type.broadcastable))) axis = list(range(input.type.ndim))
if len(axis) == 0: if len(axis) == 0:
# The acc_dtype is never a downcast compared to the input dtype # The acc_dtype is never a downcast compared to the input dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论