Commit 22416ba6 authored by Brandon T. Willard, committed by Brandon T. Willard

Replace Elemwise use with scalar Ops in broadcast_shape_iter

Parent 1615f991
from collections.abc import Collection
from functools import reduce
from typing import Iterable, Tuple, Union
import numpy as np
......@@ -6,6 +7,7 @@ import numpy.core.numeric
from numpy.core.multiarray import normalize_axis_index
import aesara
import aesara.scalar.basic as aes
from aesara.gradient import (
DisconnectedType,
_float_zeros_like,
......@@ -26,9 +28,7 @@ from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq, ge, lt
from aesara.tensor.math import max as at_max
from aesara.tensor.math import maximum, minimum, or_, prod
from aesara.tensor.math import ge, lt, maximum, minimum, prod
from aesara.tensor.math import sum as at_sum
from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from aesara.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
......@@ -1534,13 +1534,19 @@ def broadcast_shape_iter(
result_dims.append(maybe_non_bcast_shapes[0])
continue
non_bcast_vec = at.as_tensor(maybe_non_bcast_shapes)
non_bcast_vec = at.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec)
dim_max = at_abs(at_max(non_bcast_vec))
non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
assert_dim = Assert("Could not broadcast dimensions")
assert_cond = at_all(
or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, dim_max))
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
for nbv in non_bcast_vec
),
)
bcast_dim = assert_dim(dim_max, assert_cond)
......
......@@ -2890,10 +2890,10 @@ class TestShapeI(utt.InferShapeTester):
self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i)
class TestShapeFeature:
class TestSameShape:
def test_scalar(self):
x = scalar()
cst = at.constant(1).clone()
cst = at.constant(1)
o = x + cst
fgraph = FunctionGraph([x], [o], clone=False)
shape_feature = ShapeFeature()
......@@ -2902,34 +2902,42 @@ class TestShapeFeature:
def test_vector(self):
    """A vector plus a broadcastable constant should have the same shape as the vector."""
    x = vector()
    # Use the constant directly; the stale `.clone()` call was a leftover
    # removed-diff line and produced only a dead store.
    cst = at.constant(1)
    o = x + cst
    fgraph = FunctionGraph([x], [o], clone=False)
    shape_feature = ShapeFeature()
    fgraph.attach_feature(shape_feature)
    # Adding a broadcastable constant cannot change the vector's shape.
    assert shape_feature.same_shape(x, o)
def test_no_static_shapes(self):
    """`same_shape` must not hold for two vectors with no static shape info."""
    x = vector()
    y = vector()
    o = x + y
    fgraph = FunctionGraph([x, y], [o], clone=False)
    shape_feature = ShapeFeature()
    fgraph.attach_feature(shape_feature)
    # We no longer assume that `x` has the same shape as `y` simply because
    # neither has static shape information. Instead, when no static shape
    # information is available, we assume that `x` and/or `y` could have
    # shapes `(1,)` and/or `(n,)`, where `n != 1`, or any combination of
    # the two.
    assert not shape_feature.same_shape(x, o)
    # The following case isn't implemented
    assert not shape_feature.same_shape(y, o)
@pytest.mark.parametrize(
    "y_dim_0",
    # `y_dim_0=None` (unknown first dim) is expected to fail: matching an
    # unknown dimension against a known one isn't implemented yet.
    [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
)
def test_vector_dim(self, y_dim_0):
    """Per-dimension `same_shape` checks using static shape information."""
    x = at.tensor(dtype="floatX", shape=(2, None))
    y = at.tensor(dtype="floatX", shape=(y_dim_0, None))
    o = x + y
    fgraph = FunctionGraph([x, y], [o], clone=False)
    shape_feature = ShapeFeature()
    fgraph.attach_feature(shape_feature)
    # Dimension 0 of `x` and `o` are both statically 2.
    assert shape_feature.same_shape(x, o, 0, 0)
    # The following case isn't implemented
    assert not shape_feature.same_shape(y, o, 0, 0)
    # Dimension 1 is unknown on both sides, so equality can't be asserted.
    assert not shape_feature.same_shape(x, o, 1, 1)
def test_vector_dim_err(self):
x = vector()
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论