提交 a1739f6c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace some use of broadcastable with shape in tests.tensor.test_subtensor

上级 8761c77a
......@@ -465,7 +465,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
def test_ok_row(self):
n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3)))
t = n[1]
assert not any(n.type.broadcastable)
assert not any(s == 1 for s in n.type.shape)
assert isinstance(t.owner.op, Subtensor)
tval = self.eval_output_and_check(t)
assert tval.shape == (3,)
......@@ -475,7 +475,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3)))
t = n[:, 0]
assert isinstance(t.owner.op, Subtensor)
assert not any(n.type.broadcastable)
assert not any(s == 1 for s in n.type.shape)
tval = self.eval_output_and_check(t)
assert tval.shape == (2,)
assert np.all(tval == [0, 3])
......@@ -1773,15 +1773,17 @@ class TestAdvancedSubtensor:
def test_index_into_vec_w_matrix(self):
a = self.v[self.ix2]
assert a.dtype == self.v.dtype, (a.dtype, self.v.dtype)
assert a.broadcastable == self.ix2.broadcastable, (
a.broadcastable,
self.ix2.broadcastable,
assert a.type.ndim == self.ix2.type.ndim
assert all(
s1 == s2
for s1, s2 in zip(a.type.shape, self.ix2.type.shape)
if s1 == 1 or s2 == 1
)
def test_index_into_mat_w_row(self):
a = self.m[self.ixr]
assert a.dtype == self.m.dtype, (a.dtype, self.m.dtype)
assert a.broadcastable == (True, False, False)
assert a.type.shape == (1, None, None)
def test_index_w_int_and_vec(self):
# like test_ok_list, but with a single index on the first one
......@@ -2447,7 +2449,7 @@ class TestInferShape(utt.InferShapeTester):
)
abs_res = n[~isinf(n)]
assert abs_res.broadcastable == (False,)
assert abs_res.type.shape == (None,)
@config.change_flags(compute_test_value="raise")
......@@ -2468,9 +2470,7 @@ def idx_as_tensor(x):
def bcast_shape_tuple(x):
if not hasattr(x, "shape"):
return x
return tuple(
s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable)
)
return tuple(s if ss != 1 else 1 for s, ss in zip(tuple(x.shape), x.type.shape))
test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论