提交 22416ba6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace Elemwise use with scalar Ops in broadcast_shape_iter

上级 1615f991
from collections.abc import Collection from collections.abc import Collection
from functools import reduce
from typing import Iterable, Tuple, Union from typing import Iterable, Tuple, Union
import numpy as np import numpy as np
...@@ -6,6 +7,7 @@ import numpy.core.numeric ...@@ -6,6 +7,7 @@ import numpy.core.numeric
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
import aesara import aesara
import aesara.scalar.basic as aes
from aesara.gradient import ( from aesara.gradient import (
DisconnectedType, DisconnectedType,
_float_zeros_like, _float_zeros_like,
...@@ -26,9 +28,7 @@ from aesara.tensor import get_vector_length ...@@ -26,9 +28,7 @@ from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs as at_abs from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq, ge, lt from aesara.tensor.math import ge, lt, maximum, minimum, prod
from aesara.tensor.math import max as at_max
from aesara.tensor.math import maximum, minimum, or_, prod
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from aesara.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from aesara.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
...@@ -1534,13 +1534,19 @@ def broadcast_shape_iter( ...@@ -1534,13 +1534,19 @@ def broadcast_shape_iter(
result_dims.append(maybe_non_bcast_shapes[0]) result_dims.append(maybe_non_bcast_shapes[0])
continue continue
non_bcast_vec = at.as_tensor(maybe_non_bcast_shapes) non_bcast_vec = [
non_bcast_vec = at.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec) aes.switch(aes.eq(nbv, 1), -one_at, nbv)
dim_max = at_abs(at_max(non_bcast_vec)) for nbv in maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
assert_dim = Assert("Could not broadcast dimensions") assert_dim = Assert("Could not broadcast dimensions")
assert_cond = at_all( assert_cond = reduce(
or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, dim_max)) aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
for nbv in non_bcast_vec
),
) )
bcast_dim = assert_dim(dim_max, assert_cond) bcast_dim = assert_dim(dim_max, assert_cond)
......
...@@ -2890,10 +2890,10 @@ class TestShapeI(utt.InferShapeTester): ...@@ -2890,10 +2890,10 @@ class TestShapeI(utt.InferShapeTester):
self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i) self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i)
class TestShapeFeature: class TestSameShape:
def test_scalar(self): def test_scalar(self):
x = scalar() x = scalar()
cst = at.constant(1).clone() cst = at.constant(1)
o = x + cst o = x + cst
fgraph = FunctionGraph([x], [o], clone=False) fgraph = FunctionGraph([x], [o], clone=False)
shape_feature = ShapeFeature() shape_feature = ShapeFeature()
...@@ -2902,34 +2902,42 @@ class TestShapeFeature: ...@@ -2902,34 +2902,42 @@ class TestShapeFeature:
def test_vector(self): def test_vector(self):
x = vector() x = vector()
cst = at.constant(1).clone() cst = at.constant(1)
o = x + cst o = x + cst
fgraph = FunctionGraph([x], [o], clone=False) fgraph = FunctionGraph([x], [o], clone=False)
shape_feature = ShapeFeature() shape_feature = ShapeFeature()
fgraph.attach_feature(shape_feature) fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o) assert shape_feature.same_shape(x, o)
def test_vector2(self): def test_no_static_shapes(self):
x = vector() x = vector()
y = vector() y = vector()
o = x + y o = x + y
fgraph = FunctionGraph([x, y], [o], clone=False) fgraph = FunctionGraph([x, y], [o], clone=False)
shape_feature = ShapeFeature() shape_feature = ShapeFeature()
fgraph.attach_feature(shape_feature) fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o) # We no longer assume that `x` has the same shape as `y` simply because
# neither has static shape information. Instead, when there is no
# static shape information is available, we assume that `x` and/or `y`
# could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any
# combination of the two.
assert not shape_feature.same_shape(x, o)
# The following case isn't implemented # The following case isn't implemented
assert not shape_feature.same_shape(y, o) assert not shape_feature.same_shape(y, o)
def test_vector_dim(self): @pytest.mark.parametrize(
x = vector() "y_dim_0",
y = vector() [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
)
def test_vector_dim(self, y_dim_0):
x = at.tensor(dtype="floatX", shape=(2, None))
y = at.tensor(dtype="floatX", shape=(y_dim_0, None))
o = x + y o = x + y
fgraph = FunctionGraph([x, y], [o], clone=False) fgraph = FunctionGraph([x, y], [o], clone=False)
shape_feature = ShapeFeature() shape_feature = ShapeFeature()
fgraph.attach_feature(shape_feature) fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o, 0, 0) assert shape_feature.same_shape(x, o, 0, 0)
# The following case isn't implemented assert not shape_feature.same_shape(x, o, 1, 1)
assert not shape_feature.same_shape(y, o, 0, 0)
def test_vector_dim_err(self): def test_vector_dim_err(self):
x = vector() x = vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论