提交 a1739f6c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace some use of broadcastable with shape in tests.tensor.test_subtensor

上级 8761c77a
...@@ -465,7 +465,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -465,7 +465,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
def test_ok_row(self): def test_ok_row(self):
n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3))) n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3)))
t = n[1] t = n[1]
assert not any(n.type.broadcastable) assert not any(s == 1 for s in n.type.shape)
assert isinstance(t.owner.op, Subtensor) assert isinstance(t.owner.op, Subtensor)
tval = self.eval_output_and_check(t) tval = self.eval_output_and_check(t)
assert tval.shape == (3,) assert tval.shape == (3,)
...@@ -475,7 +475,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -475,7 +475,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3))) n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3)))
t = n[:, 0] t = n[:, 0]
assert isinstance(t.owner.op, Subtensor) assert isinstance(t.owner.op, Subtensor)
assert not any(n.type.broadcastable) assert not any(s == 1 for s in n.type.shape)
tval = self.eval_output_and_check(t) tval = self.eval_output_and_check(t)
assert tval.shape == (2,) assert tval.shape == (2,)
assert np.all(tval == [0, 3]) assert np.all(tval == [0, 3])
...@@ -1773,15 +1773,17 @@ class TestAdvancedSubtensor: ...@@ -1773,15 +1773,17 @@ class TestAdvancedSubtensor:
def test_index_into_vec_w_matrix(self): def test_index_into_vec_w_matrix(self):
a = self.v[self.ix2] a = self.v[self.ix2]
assert a.dtype == self.v.dtype, (a.dtype, self.v.dtype) assert a.dtype == self.v.dtype, (a.dtype, self.v.dtype)
assert a.broadcastable == self.ix2.broadcastable, ( assert a.type.ndim == self.ix2.type.ndim
a.broadcastable, assert all(
self.ix2.broadcastable, s1 == s2
for s1, s2 in zip(a.type.shape, self.ix2.type.shape)
if s1 == 1 or s2 == 1
) )
def test_index_into_mat_w_row(self): def test_index_into_mat_w_row(self):
a = self.m[self.ixr] a = self.m[self.ixr]
assert a.dtype == self.m.dtype, (a.dtype, self.m.dtype) assert a.dtype == self.m.dtype, (a.dtype, self.m.dtype)
assert a.broadcastable == (True, False, False) assert a.type.shape == (1, None, None)
def test_index_w_int_and_vec(self): def test_index_w_int_and_vec(self):
# like test_ok_list, but with a single index on the first one # like test_ok_list, but with a single index on the first one
...@@ -2447,7 +2449,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2447,7 +2449,7 @@ class TestInferShape(utt.InferShapeTester):
) )
abs_res = n[~isinf(n)] abs_res = n[~isinf(n)]
assert abs_res.broadcastable == (False,) assert abs_res.type.shape == (None,)
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
...@@ -2468,9 +2470,7 @@ def idx_as_tensor(x): ...@@ -2468,9 +2470,7 @@ def idx_as_tensor(x):
def bcast_shape_tuple(x): def bcast_shape_tuple(x):
if not hasattr(x, "shape"): if not hasattr(x, "shape"):
return x return x
return tuple( return tuple(s if ss != 1 else 1 for s, ss in zip(tuple(x.shape), x.type.shape))
s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable)
)
test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True])) test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论