Commit 46a46af2 authored by Brandon T. Willard, committed by Brandon T. Willard

Replace use of broadcastable with shape in aesara.tensor.basic

Parent 94c2e4c2
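
The rewrite relies on the correspondence between the legacy broadcastable flags and the static type.shape on TensorType: a dimension is broadcastable exactly when its static size is 1, and None means the size is statically unknown (and therefore not assumed broadcastable). A minimal sketch of that mapping, not part of the commit, using the same TensorType constructor that appears in the tests below:

    from aesara.tensor.type import TensorType

    x = TensorType("float64", shape=(1, None, 3))()
    # A static size of 1 marks a broadcastable dimension; None leaves it unknown.
    assert x.type.shape == (1, None, 3)
    assert x.type.broadcastable == (True, False, False)
    # The inverse mapping used throughout this commit:
    assert tuple(s == 1 for s in x.type.shape) == x.type.broadcastable
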
@@ -28,7 +28,7 @@ from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.type import Type
from aesara.graph.type import HasShape, Type
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray
@@ -348,8 +348,8 @@ def get_scalar_constant_value(
if isinstance(inp, Constant):
return np.asarray(np.shape(inp.data)[i])
# The shape of a broadcastable dimension is 1
if hasattr(inp.type, "broadcastable") and inp.type.broadcastable[i]:
return np.asarray(1)
if isinstance(inp.type, HasShape) and inp.type.shape[i] is not None:
return np.asarray(inp.type.shape[i])
# Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would
@@ -502,21 +502,16 @@ def get_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
grandparent = leftmost_parent.owner.inputs[0]
gp_broadcastable = grandparent.type.broadcastable
gp_shape = grandparent.type.shape
ndim = grandparent.type.ndim
if grandparent.owner and isinstance(
grandparent.owner.op, Unbroadcast
):
ggp_broadcastable = grandparent.owner.inputs[0].broadcastable
l = [
b1 or b2
for b1, b2 in zip(ggp_broadcastable, gp_broadcastable)
]
gp_broadcastable = tuple(l)
ggp_shape = grandparent.owner.inputs[0].type.shape
l = [s1 == 1 or s2 == 1 for s1, s2 in zip(ggp_shape, gp_shape)]
gp_shape = tuple(l)
assert ndim == len(gp_broadcastable)
if not (idx < len(gp_broadcastable)):
if not (idx < ndim):
msg = (
"get_scalar_constant_value detected "
f"deterministic IndexError: x.shape[{int(idx)}] "
@@ -528,8 +523,9 @@ def get_scalar_constant_value(
msg += f" x={v}"
raise ValueError(msg)
if gp_broadcastable[idx]:
return np.asarray(1)
gp_shape_val = gp_shape[idx]
if gp_shape_val is not None and gp_shape_val > -1:
return np.asarray(gp_shape_val)
if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx])
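
Taken together, the get_scalar_constant_value changes above let any statically known dimension fold to a constant, not only the broadcastable (size-1) dimensions as before. A hedged usage sketch of the post-commit behaviour (the variable is illustrative; Shape_i and get_scalar_constant_value are the existing Aesara names):

    from aesara.tensor.basic import get_scalar_constant_value
    from aesara.tensor.shape import Shape_i
    from aesara.tensor.type import TensorType

    x = TensorType("float64", shape=(None, 5))()
    # The second dimension is statically known, so its symbolic size folds
    # to a constant even though the first dimension stays unknown.
    assert get_scalar_constant_value(Shape_i(1)(x)) == 5
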
@@ -1511,15 +1507,16 @@ class Alloc(COp):
axis_kept = []
for i, (ib, gb) in enumerate(
zip(
inputs[0].broadcastable,
inputs[0].type.shape,
# We need the dimensions corresponding to x
grads[0].broadcastable[-inputs[0].ndim :],
grads[0].type.shape[-inputs[0].ndim :],
)
):
if ib and not gb:
if ib == 1 and gb != 1:
axis_broadcasted.append(i + n_axes_to_sum)
else:
axis_kept.append(i)
gx = gz.sum(axis=axis + axis_broadcasted)
if axis_broadcasted:
new_order = ["x"] * x.ndim
@@ -1865,11 +1862,14 @@ def transpose(x, axes=None):
"""
_x = as_tensor_variable(x)
if axes is None:
axes = list(range((_x.ndim - 1), -1, -1))
ret = DimShuffle(_x.broadcastable, axes)(_x)
if _x.name and axes == list(range((_x.ndim - 1), -1, -1)):
axes = list(range((_x.type.ndim - 1), -1, -1))
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T"
return ret
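
The DimShuffle constructor still expects per-dimension broadcastable flags, so both here and in PermuteRowElements.grad below the static shape is mapped back with s == 1. A small illustrative sketch, not taken from the commit:

    from aesara.tensor.elemwise import DimShuffle
    from aesara.tensor.type import TensorType

    x = TensorType("float64", shape=(None, 3))()
    # Reverse the two dimensions, deriving the broadcast pattern from the
    # static shape exactly as the new transpose() code does.
    y = DimShuffle(tuple(s == 1 for s in x.type.shape), (1, 0))(x)
    assert y.type.broadcastable == (False, False)
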
@@ -3207,11 +3207,11 @@ class PermuteRowElements(Op):
if xs0 == ys0:
for i in range(xs0):
self._rec_perform(node, x[i], y[i], inverse, out[i], curdim + 1)
elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
elif ys0 == 1 and node.inputs[1].type.shape[curdim] == 1:
# Broadcast y
for i in range(xs0):
self._rec_perform(node, x[i], y[0], inverse, out[i], curdim + 1)
elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
elif xs0 == 1 and node.inputs[0].type.shape[curdim] == 1:
# Broadcast x
for i in range(ys0):
self._rec_perform(node, x[0], y[i], inverse, out[i], curdim + 1)
@@ -3270,7 +3270,7 @@ class PermuteRowElements(Op):
broadcasted_dims = [
dim
for dim in range(gz.type.ndim)
if x.type.broadcastable[dim] and not gz.type.broadcastable[dim]
if x.type.shape[dim] == 1 and gz.type.shape[dim] != 1
]
gx = Sum(axis=broadcasted_dims)(gx)
@@ -3285,8 +3285,13 @@ class PermuteRowElements(Op):
newdims.append(i)
i += 1
gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
assert gx.type.broadcastable == x.type.broadcastable
gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
assert gx.type.ndim == x.type.ndim
assert all(
s1 == s2
for s1, s2 in zip(gx.type.shape, x.type.shape)
if s1 == 1 or s2 == 1
)
# if x is an integer type, then so is the output.
# this means f(x+eps) = f(x) so the gradient with respect
@@ -458,10 +458,10 @@ class TestMakeVector(utt.InferShapeTester):
res = MakeVector("int32")(a, b)
res = MakeVector()(a)
assert res.broadcastable == (True,)
assert res.type.shape == (1,)
res = MakeVector()()
assert res.broadcastable == (False,)
assert res.type.shape == (0,)
def test_infer_shape(self):
adscal = dscalar()
@@ -1665,18 +1665,18 @@ class TestJoinAndSplit:
a = self.shared(a_val, shape=(None, None, 1))
b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(1, a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
# Opt can replace the int by an Aesara constant
c = self.join_op(constant(1), a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
# In case future opts insert other useless stuff
c = self.join_op(cast(constant(1), dtype="int32"), a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1703,7 +1703,7 @@ class TestJoinAndSplit:
a = self.shared(a_val, shape=(None, None, 1))
b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(0, a, b)
assert not c.type.broadcastable[0]
assert c.type.shape[0] != 1
f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1736,7 +1736,7 @@ class TestJoinAndSplit:
a = self.shared(a_val, shape=(1, None, 1))
b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(0, a, b)
assert not c.type.broadcastable[0]
assert c.type.shape[0] != 1
f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1754,9 +1754,9 @@ class TestJoinAndSplit:
a_val = rng.random((1, 4, 1)).astype(self.floatX)
a = self.shared(a_val, shape=(1, None, 1))
b = self.join_op(0, a)
assert b.type.broadcastable[0]
assert b.type.broadcastable[2]
assert not b.type.broadcastable[1]
assert b.type.shape[0] == 1
assert b.type.shape[2] == 1
assert b.type.shape[1] != 1
f = function([], b, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1782,13 +1782,13 @@ class TestJoinAndSplit:
d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()
f = self.join_op(0, a, b, c, d, e)
fb = f.type.broadcastable
fb = tuple(s == 1 for s in f.type.shape)
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
g = self.join_op(1, a, b, c, d, e)
gb = g.type.broadcastable
gb = tuple(s == 1 for s in g.type.shape)
assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
h = self.join_op(4, a, b, c, d, e)
hb = h.type.broadcastable
hb = tuple(s == 1 for s in h.type.shape)
assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
f = function([a, b, c, d, e], f, mode=self.mode)
@@ -1981,8 +1981,8 @@ def test_TensorFromScalar():
s = aes.constant(56)
t = tensor_from_scalar(s)
assert t.owner.op is tensor_from_scalar
assert t.type.broadcastable == (), t.type.broadcastable
assert t.type.ndim == 0, t.type.ndim
assert t.type.shape == ()
assert t.type.ndim == 0
assert t.type.dtype == s.type.dtype
v = eval_outputs([t])
@@ -2129,23 +2129,23 @@ def test_flatten_broadcastable():
inp = TensorType("float64", shape=(None, None, None, None))()
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
assert out.type.shape == (None, None)
inp = TensorType("float64", shape=(None, None, None, 1))()
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
assert out.type.shape == (None, None)
inp = TensorType("float64", shape=(None, 1, None, 1))()
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
assert out.type.shape == (None, None)
inp = TensorType("float64", shape=(None, 1, 1, 1))()
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, True)
assert out.type.shape == (None, 1)
inp = TensorType("float64", shape=(1, None, 1, 1))()
out = flatten(inp, ndim=3)
assert out.broadcastable == (True, False, True)
assert out.type.shape == (1, None, 1)
def test_flatten_ndim_invalid():
@@ -2938,8 +2938,8 @@ class TestPermuteRowElements:
def test_3b_2(self):
# Test permute_row_elements on a more complex broadcasting pattern:
# input.type.broadcastable = (False, True, False),
# p.type.broadcastable = (False, False).
# input.type.shape = (None, 1, None),
# p.type.shape = (None, None).
input = TensorType("floatX", shape=(None, 1, None))()
p = imatrix()
@@ -4046,7 +4046,7 @@ class TestChoose(utt.InferShapeTester):
B = np.asarray(np.random.random((4, 1)), dtype="float32")
for m in self.modes:
f = function([a, b], choose(a, b, mode=m))
assert choose(a, b, mode=m).broadcastable[0]
assert choose(a, b, mode=m).type.shape[0] == 1
t_c = f(A, B)
n_c = np.choose(A, B, mode=m)
assert np.allclose(t_c, n_c)