提交 46a46af2 作者: Brandon T. Willard 提交者: Brandon T. Willard

Replace use of broadcastable with shape in aesara.tensor.basic

上级 94c2e4c2
...@@ -28,7 +28,7 @@ from aesara.graph.basic import Apply, Constant, Variable ...@@ -28,7 +28,7 @@ from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.rewriting.utils import rewrite_graph from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.type import Type from aesara.graph.type import HasShape, Type
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
...@@ -348,8 +348,8 @@ def get_scalar_constant_value( ...@@ -348,8 +348,8 @@ def get_scalar_constant_value(
if isinstance(inp, Constant): if isinstance(inp, Constant):
return np.asarray(np.shape(inp.data)[i]) return np.asarray(np.shape(inp.data)[i])
# The shape of a broadcastable dimension is 1 # The shape of a broadcastable dimension is 1
if hasattr(inp.type, "broadcastable") and inp.type.broadcastable[i]: if isinstance(inp.type, HasShape) and inp.type.shape[i] is not None:
return np.asarray(1) return np.asarray(inp.type.shape[i])
# Don't act as the constant_folding optimization here as this # Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would # fct is used too early in the optimization phase. This would
...@@ -502,21 +502,16 @@ def get_scalar_constant_value( ...@@ -502,21 +502,16 @@ def get_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
grandparent = leftmost_parent.owner.inputs[0] grandparent = leftmost_parent.owner.inputs[0]
gp_broadcastable = grandparent.type.broadcastable gp_shape = grandparent.type.shape
ndim = grandparent.type.ndim ndim = grandparent.type.ndim
if grandparent.owner and isinstance( if grandparent.owner and isinstance(
grandparent.owner.op, Unbroadcast grandparent.owner.op, Unbroadcast
): ):
ggp_broadcastable = grandparent.owner.inputs[0].broadcastable ggp_shape = grandparent.owner.inputs[0].type.shape
l = [ l = [s1 == 1 or s2 == 1 for s1, s2 in zip(ggp_shape, gp_shape)]
b1 or b2 gp_shape = tuple(l)
for b1, b2 in zip(ggp_broadcastable, gp_broadcastable)
]
gp_broadcastable = tuple(l)
assert ndim == len(gp_broadcastable)
if not (idx < len(gp_broadcastable)): if not (idx < ndim):
msg = ( msg = (
"get_scalar_constant_value detected " "get_scalar_constant_value detected "
f"deterministic IndexError: x.shape[{int(idx)}] " f"deterministic IndexError: x.shape[{int(idx)}] "
...@@ -528,8 +523,9 @@ def get_scalar_constant_value( ...@@ -528,8 +523,9 @@ def get_scalar_constant_value(
msg += f" x={v}" msg += f" x={v}"
raise ValueError(msg) raise ValueError(msg)
if gp_broadcastable[idx]: gp_shape_val = gp_shape[idx]
return np.asarray(1) if gp_shape_val is not None and gp_shape_val > -1:
return np.asarray(gp_shape_val)
if isinstance(grandparent, Constant): if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx]) return np.asarray(np.shape(grandparent.data)[idx])
...@@ -1511,15 +1507,16 @@ class Alloc(COp): ...@@ -1511,15 +1507,16 @@ class Alloc(COp):
axis_kept = [] axis_kept = []
for i, (ib, gb) in enumerate( for i, (ib, gb) in enumerate(
zip( zip(
inputs[0].broadcastable, inputs[0].type.shape,
# We need the dimensions corresponding to x # We need the dimensions corresponding to x
grads[0].broadcastable[-inputs[0].ndim :], grads[0].type.shape[-inputs[0].ndim :],
) )
): ):
if ib and not gb: if ib == 1 and gb != 1:
axis_broadcasted.append(i + n_axes_to_sum) axis_broadcasted.append(i + n_axes_to_sum)
else: else:
axis_kept.append(i) axis_kept.append(i)
gx = gz.sum(axis=axis + axis_broadcasted) gx = gz.sum(axis=axis + axis_broadcasted)
if axis_broadcasted: if axis_broadcasted:
new_order = ["x"] * x.ndim new_order = ["x"] * x.ndim
...@@ -1865,11 +1862,14 @@ def transpose(x, axes=None): ...@@ -1865,11 +1862,14 @@ def transpose(x, axes=None):
""" """
_x = as_tensor_variable(x) _x = as_tensor_variable(x)
if axes is None: if axes is None:
axes = list(range((_x.ndim - 1), -1, -1)) axes = list(range((_x.type.ndim - 1), -1, -1))
ret = DimShuffle(_x.broadcastable, axes)(_x) ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
if _x.name and axes == list(range((_x.ndim - 1), -1, -1)):
if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T" ret.name = _x.name + ".T"
return ret return ret
...@@ -3207,11 +3207,11 @@ class PermuteRowElements(Op): ...@@ -3207,11 +3207,11 @@ class PermuteRowElements(Op):
if xs0 == ys0: if xs0 == ys0:
for i in range(xs0): for i in range(xs0):
self._rec_perform(node, x[i], y[i], inverse, out[i], curdim + 1) self._rec_perform(node, x[i], y[i], inverse, out[i], curdim + 1)
elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]: elif ys0 == 1 and node.inputs[1].type.shape[curdim] == 1:
# Broadcast y # Broadcast y
for i in range(xs0): for i in range(xs0):
self._rec_perform(node, x[i], y[0], inverse, out[i], curdim + 1) self._rec_perform(node, x[i], y[0], inverse, out[i], curdim + 1)
elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]: elif xs0 == 1 and node.inputs[0].type.shape[curdim] == 1:
# Broadcast x # Broadcast x
for i in range(ys0): for i in range(ys0):
self._rec_perform(node, x[0], y[i], inverse, out[i], curdim + 1) self._rec_perform(node, x[0], y[i], inverse, out[i], curdim + 1)
...@@ -3270,7 +3270,7 @@ class PermuteRowElements(Op): ...@@ -3270,7 +3270,7 @@ class PermuteRowElements(Op):
broadcasted_dims = [ broadcasted_dims = [
dim dim
for dim in range(gz.type.ndim) for dim in range(gz.type.ndim)
if x.type.broadcastable[dim] and not gz.type.broadcastable[dim] if x.type.shape[dim] == 1 and gz.type.shape[dim] != 1
] ]
gx = Sum(axis=broadcasted_dims)(gx) gx = Sum(axis=broadcasted_dims)(gx)
...@@ -3285,8 +3285,13 @@ class PermuteRowElements(Op): ...@@ -3285,8 +3285,13 @@ class PermuteRowElements(Op):
newdims.append(i) newdims.append(i)
i += 1 i += 1
gx = DimShuffle(gx.type.broadcastable, newdims)(gx) gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
assert gx.type.broadcastable == x.type.broadcastable assert gx.type.ndim == x.type.ndim
assert all(
s1 == s2
for s1, s2 in zip(gx.type.shape, x.type.shape)
if s1 == 1 or s2 == 1
)
# if x is an integer type, then so is the output. # if x is an integer type, then so is the output.
# this means f(x+eps) = f(x) so the gradient with respect # this means f(x+eps) = f(x) so the gradient with respect
......
...@@ -458,10 +458,10 @@ class TestMakeVector(utt.InferShapeTester): ...@@ -458,10 +458,10 @@ class TestMakeVector(utt.InferShapeTester):
res = MakeVector("int32")(a, b) res = MakeVector("int32")(a, b)
res = MakeVector()(a) res = MakeVector()(a)
assert res.broadcastable == (True,) assert res.type.shape == (1,)
res = MakeVector()() res = MakeVector()()
assert res.broadcastable == (False,) assert res.type.shape == (0,)
def test_infer_shape(self): def test_infer_shape(self):
adscal = dscalar() adscal = dscalar()
...@@ -1665,18 +1665,18 @@ class TestJoinAndSplit: ...@@ -1665,18 +1665,18 @@ class TestJoinAndSplit:
a = self.shared(a_val, shape=(None, None, 1)) a = self.shared(a_val, shape=(None, None, 1))
b = self.shared(b_val, shape=(1, None, 1)) b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(1, a, b) c = self.join_op(1, a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2] assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert not c.type.broadcastable[1] assert c.type.shape[1] != 1
# Opt can replace the int by an Aesara constant # Opt can replace the int by an Aesara constant
c = self.join_op(constant(1), a, b) c = self.join_op(constant(1), a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2] assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert not c.type.broadcastable[1] assert c.type.shape[1] != 1
# In case future opts insert other useless stuff # In case future opts insert other useless stuff
c = self.join_op(cast(constant(1), dtype="int32"), a, b) c = self.join_op(cast(constant(1), dtype="int32"), a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2] assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert not c.type.broadcastable[1] assert c.type.shape[1] != 1
f = function([], c, mode=self.mode) f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -1703,7 +1703,7 @@ class TestJoinAndSplit: ...@@ -1703,7 +1703,7 @@ class TestJoinAndSplit:
a = self.shared(a_val, shape=(None, None, 1)) a = self.shared(a_val, shape=(None, None, 1))
b = self.shared(b_val, shape=(1, None, 1)) b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(0, a, b) c = self.join_op(0, a, b)
assert not c.type.broadcastable[0] assert c.type.shape[0] != 1
f = function([], c, mode=self.mode) f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -1736,7 +1736,7 @@ class TestJoinAndSplit: ...@@ -1736,7 +1736,7 @@ class TestJoinAndSplit:
a = self.shared(a_val, shape=(1, None, 1)) a = self.shared(a_val, shape=(1, None, 1))
b = self.shared(b_val, shape=(1, None, 1)) b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(0, a, b) c = self.join_op(0, a, b)
assert not c.type.broadcastable[0] assert c.type.shape[0] != 1
f = function([], c, mode=self.mode) f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -1754,9 +1754,9 @@ class TestJoinAndSplit: ...@@ -1754,9 +1754,9 @@ class TestJoinAndSplit:
a_val = rng.random((1, 4, 1)).astype(self.floatX) a_val = rng.random((1, 4, 1)).astype(self.floatX)
a = self.shared(a_val, shape=(1, None, 1)) a = self.shared(a_val, shape=(1, None, 1))
b = self.join_op(0, a) b = self.join_op(0, a)
assert b.type.broadcastable[0] assert b.type.shape[0] == 1
assert b.type.broadcastable[2] assert b.type.shape[2] == 1
assert not b.type.broadcastable[1] assert b.type.shape[1] != 1
f = function([], b, mode=self.mode) f = function([], b, mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -1782,13 +1782,13 @@ class TestJoinAndSplit: ...@@ -1782,13 +1782,13 @@ class TestJoinAndSplit:
d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))() d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))() e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()
f = self.join_op(0, a, b, c, d, e) f = self.join_op(0, a, b, c, d, e)
fb = f.type.broadcastable fb = tuple(s == 1 for s in f.type.shape)
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5] assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
g = self.join_op(1, a, b, c, d, e) g = self.join_op(1, a, b, c, d, e)
gb = g.type.broadcastable gb = tuple(s == 1 for s in g.type.shape)
assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5] assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
h = self.join_op(4, a, b, c, d, e) h = self.join_op(4, a, b, c, d, e)
hb = h.type.broadcastable hb = tuple(s == 1 for s in h.type.shape)
assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5] assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
f = function([a, b, c, d, e], f, mode=self.mode) f = function([a, b, c, d, e], f, mode=self.mode)
...@@ -1981,8 +1981,8 @@ def test_TensorFromScalar(): ...@@ -1981,8 +1981,8 @@ def test_TensorFromScalar():
s = aes.constant(56) s = aes.constant(56)
t = tensor_from_scalar(s) t = tensor_from_scalar(s)
assert t.owner.op is tensor_from_scalar assert t.owner.op is tensor_from_scalar
assert t.type.broadcastable == (), t.type.broadcastable assert t.type.shape == ()
assert t.type.ndim == 0, t.type.ndim assert t.type.ndim == 0
assert t.type.dtype == s.type.dtype assert t.type.dtype == s.type.dtype
v = eval_outputs([t]) v = eval_outputs([t])
...@@ -2129,23 +2129,23 @@ def test_flatten_broadcastable(): ...@@ -2129,23 +2129,23 @@ def test_flatten_broadcastable():
inp = TensorType("float64", shape=(None, None, None, None))() inp = TensorType("float64", shape=(None, None, None, None))()
out = flatten(inp, ndim=2) out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False) assert out.type.shape == (None, None)
inp = TensorType("float64", shape=(None, None, None, 1))() inp = TensorType("float64", shape=(None, None, None, 1))()
out = flatten(inp, ndim=2) out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False) assert out.type.shape == (None, None)
inp = TensorType("float64", shape=(None, 1, None, 1))() inp = TensorType("float64", shape=(None, 1, None, 1))()
out = flatten(inp, ndim=2) out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False) assert out.type.shape == (None, None)
inp = TensorType("float64", shape=(None, 1, 1, 1))() inp = TensorType("float64", shape=(None, 1, 1, 1))()
out = flatten(inp, ndim=2) out = flatten(inp, ndim=2)
assert out.broadcastable == (False, True) assert out.type.shape == (None, 1)
inp = TensorType("float64", shape=(1, None, 1, 1))() inp = TensorType("float64", shape=(1, None, 1, 1))()
out = flatten(inp, ndim=3) out = flatten(inp, ndim=3)
assert out.broadcastable == (True, False, True) assert out.type.shape == (1, None, 1)
def test_flatten_ndim_invalid(): def test_flatten_ndim_invalid():
...@@ -2938,8 +2938,8 @@ class TestPermuteRowElements: ...@@ -2938,8 +2938,8 @@ class TestPermuteRowElements:
def test_3b_2(self): def test_3b_2(self):
# Test permute_row_elements on a more complex broadcasting pattern: # Test permute_row_elements on a more complex broadcasting pattern:
# input.type.broadcastable = (False, True, False), # input.type.shape = (None, 1, None),
# p.type.broadcastable = (False, False). # p.type.shape = (None, None).
input = TensorType("floatX", shape=(None, 1, None))() input = TensorType("floatX", shape=(None, 1, None))()
p = imatrix() p = imatrix()
...@@ -4046,7 +4046,7 @@ class TestChoose(utt.InferShapeTester): ...@@ -4046,7 +4046,7 @@ class TestChoose(utt.InferShapeTester):
B = np.asarray(np.random.random((4, 1)), dtype="float32") B = np.asarray(np.random.random((4, 1)), dtype="float32")
for m in self.modes: for m in self.modes:
f = function([a, b], choose(a, b, mode=m)) f = function([a, b], choose(a, b, mode=m))
assert choose(a, b, mode=m).broadcastable[0] assert choose(a, b, mode=m).type.shape[0] == 1
t_c = f(A, B) t_c = f(A, B)
n_c = np.choose(A, B, mode=m) n_c = np.choose(A, B, mode=m)
assert np.allclose(t_c, n_c) assert np.allclose(t_c, n_c)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论