Commit 94c2e4c2 authored by Brandon T. Willard, committed by Brandon T. Willard

Replace use of broadcastable with shape in aesara.tensor.elemwise

Parent f1dc0897
@@ -62,17 +62,17 @@ class DimShuffle(ExternalCOp):
If `j = new_order[i]` is an index, the output's ith dimension
will be the input's jth dimension.
If `new_order[i]` is `x`, the output's ith dimension will
be 1 and Broadcast operations will be allowed to do broadcasting
be 1 and broadcast operations will be allowed to do broadcasting
over that dimension.
If `input.broadcastable[i] == False` then `i` must be found in new_order.
If `input.type.shape[i] != 1` then `i` must be found in `new_order`.
Broadcastable dimensions, on the other hand, can be discarded.
.. code-block:: python
DimShuffle((False, False, False), ['x', 2, 'x', 0, 1])
This op will only work on 3d tensors with no broadcastable
This `Op` will only work on 3d tensors with no broadcastable
dimensions. The first dimension will be broadcastable,
then we will have the third dimension of the input tensor as
the second of the resulting tensor, etc. If the tensor has
@@ -83,7 +83,7 @@ class DimShuffle(ExternalCOp):
DimShuffle((True, False), [1])
This op will only work on 2d tensors with the first dimension
This `Op` will only work on 2d tensors with the first dimension
broadcastable.
The second dimension of the input tensor will be the first dimension of
the resulting tensor.
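A minimal sketch of the two patterns described in this docstring, assuming the usual `aesara.tensor` entry points and the `dimshuffle` convenience method; the printed tuples are the static `type.shape` values:

```python
import aesara.tensor as at

# A 3d input with no broadcastable dimensions, as in the first example.
x = at.tensor3("x")                  # x.type.shape == (None, None, None)
y = x.dimshuffle("x", 2, "x", 0, 1)  # insert two length-1 dims and permute the rest
print(y.type.shape)                  # (1, None, 1, None, None)

# A 2d input whose first dimension is broadcastable, as in the second example.
r = at.row("r")                      # r.type.shape == (1, None)
s = r.dimshuffle(1)                  # the broadcastable dim is discarded
print(s.type.shape)                  # (None,)
```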
@@ -186,7 +186,7 @@ class DimShuffle(ExternalCOp):
def make_node(self, _input):
input = as_tensor_variable(_input)
ib = tuple(input.type.broadcastable)
ib = tuple(s == 1 for s in input.type.shape)
if ib != self.input_broadcastable:
if len(ib) != len(self.input_broadcastable):
raise TypeError(
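The `make_node` change relies on the equivalence between a `True` entry in `type.broadcastable` and a static length of 1 in `type.shape`; a quick sketch of that equivalence (the `at.row` variable is only an example):

```python
import aesara.tensor as at

x = at.row("x")                           # static shape (1, None)
ib = tuple(s == 1 for s in x.type.shape)  # (True, False)
assert ib == x.type.broadcastable         # both encodings mark the same dims
```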
@@ -258,7 +258,7 @@ class DimShuffle(ExternalCOp):
(x,) = inp
(gz,) = grads
gz = as_tensor_variable(gz)
grad_order = ["x"] * len(x.type.broadcastable)
grad_order = ["x"] * x.type.ndim
for i, v in enumerate(self.new_order):
if v != "x":
grad_order[v] = i
@@ -269,7 +269,7 @@ class DimShuffle(ExternalCOp):
return [inp[0].zeros_like(dtype=config.floatX)]
else:
return [
DimShuffle(gz.type.broadcastable, grad_order)(
DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
Elemwise(scalar_identity)(gz)
)
]
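A plain-Python sketch of how the loop above inverts `new_order` to build the pattern applied to `gz`; the forward pattern and the 2d input are made-up examples:

```python
# Invert a hypothetical forward DimShuffle pattern on a 2d input.
new_order = ["x", 1, 0]    # forward pattern: add a leading broadcast dim, swap axes
input_ndim = 2

grad_order = ["x"] * input_ndim
for i, v in enumerate(new_order):
    if v != "x":
        grad_order[v] = i  # gz's axis i maps back onto input axis v

print(grad_order)          # [2, 1]: the input's grad takes gz's axes 2 and 1
```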
@@ -406,7 +406,7 @@ class Elemwise(OpenMPOp):
# TODO: use LComplete instead
args.append(
dim_shuffle(
input.type.broadcastable,
tuple(1 if s == 1 else None for s in input.type.shape),
["x"] * difference + list(range(length)),
)(input)
)
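The pattern handed to `dim_shuffle` is now a shape-like tuple rather than a boolean `broadcastable` tuple: `1` for dimensions statically known to have length 1 and `None` for the rest. A small sketch of that mapping, with a made-up static shape:

```python
# Map a static shape to the shape-like pattern built above.
static_shape = (1, None, 5)   # hypothetical input.type.shape
pattern = tuple(1 if s == 1 else None for s in static_shape)
print(pattern)                # (1, None, None)
```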
@@ -452,11 +452,11 @@ class Elemwise(OpenMPOp):
inplace_pattern = self.inplace_pattern
if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items():
for ob, ib in zip(
for out_s, in_s in zip(
out_shapes[overwriter],
inputs[overwritten].type.broadcastable,
inputs[overwritten].type.shape,
):
if ib and not ob == 1:
if in_s == 1 and out_s != 1:
raise ValueError(
"Operation cannot be done inplace on an input "
"with broadcasted dimensions."
@@ -578,8 +578,8 @@ class Elemwise(OpenMPOp):
# TODO: only count dimensions that were effectively broadcasted
to_sum = [
j
for j, bcast in enumerate(ipt.type.broadcastable)
if bcast and not outs[0].broadcastable[j]
for j, in_s in enumerate(ipt.type.shape)
if in_s == 1 and outs[0].type.shape[j] != 1
]
if to_sum:
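The rewritten comprehension collects the axes along which an input was actually broadcast against the output, so the gradient can be summed back down to the input's shape; a plain-Python sketch with made-up static shapes:

```python
input_shape = (1, 5, 1)   # hypothetical ipt.type.shape
output_shape = (4, 5, 1)  # hypothetical outs[0].type.shape

# Axes where the input is length 1 but the output is not: sum the grad there.
to_sum = [
    j
    for j, in_s in enumerate(input_shape)
    if in_s == 1 and output_shape[j] != 1
]
print(to_sum)             # [0]: only axis 0 was effectively broadcast
```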
@@ -614,7 +614,7 @@ class Elemwise(OpenMPOp):
f"{str(self.scalar_op)}.grad returned {str(type(scalar_igrads))} instead of list or tuple"
)
nd = len(inputs[0].type.broadcastable) # this is the same for everyone
nd = inputs[0].type.ndim # this is the same for everyone
def transform(r):
# From a graph of ScalarOps, make a graph of Broadcast ops.
@@ -897,7 +897,7 @@ class Elemwise(OpenMPOp):
# for each input:
# same as range(ndim), but with 'x' at all broadcastable positions
orders = [
[x and "x" or i for i, x in enumerate(input.type.broadcastable)]
[s == 1 and "x" or i for i, s in enumerate(input.type.shape)]
for input in inputs
]
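The `orders` comprehension keeps the old `cond and "x" or i` idiom but now tests the static shape; the same computation written with a conditional expression, using a made-up shape:

```python
static_shape = (1, None, 7)   # hypothetical input.type.shape
# 'x' at every statically length-1 position, the axis index everywhere else.
order = ["x" if s == 1 else i for i, s in enumerate(static_shape)]
print(order)                  # ['x', 1, 2]
```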
@@ -920,7 +920,7 @@ class Elemwise(OpenMPOp):
[
f"PyArray_ISFORTRAN({arr})"
for arr, var in z
if not all(var.broadcastable)
if not all(s == 1 for s in var.type.shape)
]
)
# If it is a scalar, make it c contig to prevent problem with
@@ -1005,7 +1005,7 @@ class Elemwise(OpenMPOp):
or
# Use simpler code when output ndim == 0 or 1
# or for broadcated scalar.
all(node.outputs[0].broadcastable)
all(s == 1 for s in node.outputs[0].type.shape)
):
if nnested:
all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
@@ -1077,7 +1077,7 @@ class Elemwise(OpenMPOp):
all(o.ndim >= 1 for o in node.outputs)
and
# Don't use the contig code for broadcasted scalar.
not all(node.outputs[0].broadcastable)
not all(s == 1 for s in node.outputs[0].type.shape)
):
contig = None
try:
@@ -1110,7 +1110,7 @@ class Elemwise(OpenMPOp):
"""
index = ""
for x, var in zip(inames + onames, inputs + node.outputs):
if not all(var.broadcastable):
if not all(s == 1 for s in var.type.shape):
contig += (
"""
dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
@@ -1144,18 +1144,19 @@ class Elemwise(OpenMPOp):
)
if contig is not None:
z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape)
cond1 = " && ".join(
[
"PyArray_ISCONTIGUOUS(%s)" % arr
for arr, var in z
if not all(var.broadcastable)
if not all_broadcastable
]
)
cond2 = " && ".join(
[
"PyArray_ISFORTRAN(%s)" % arr
for arr, var in z
if not all(var.broadcastable)
if not all_broadcastable
]
)
loop = (
@@ -1388,13 +1389,7 @@ class CAReduce(COp):
axis = self.axis
if axis is None:
return ((),)
return (
[
ishape[i]
for (i, b) in enumerate(node.inputs[0].type.broadcastable)
if i not in axis
],
)
return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)
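The simplified `infer_shape` keeps just the entries of the input shape whose axes are not reduced over; a short stand-alone sketch of the same computation (the axis tuple and shape entries are examples):

```python
ishape = (2, 3, 4)   # hypothetical symbolic input shape entries
axis = (0, 2)        # hypothetical self.axis
ndim = len(ishape)

out_shape = [ishape[i] for i in range(ndim) if i not in axis]
print(out_shape)     # [3]: only the non-reduced axis remains
```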
def _c_all(self, node, name, inames, onames, sub):
@@ -1419,7 +1414,7 @@ class CAReduce(COp):
axis = self.axis
if axis is None:
axis = list(range(len(input.type.broadcastable)))
axis = list(range(input.type.ndim))
if len(axis) == 0:
# The acc_dtype is never a downcast compared to the input dtype