Commit 7d72236a authored by Brandon T. Willard, committed by Brandon T. Willard

Replace use of broadcastable with shape in aesara.tensor.shape

Parent 46a46af2
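For context: in Aesara's `TensorType`, `broadcastable[i]` is `True` exactly when the static `shape[i]` is `1`, so every `broadcastable` check below has a direct `shape` equivalent. A minimal sketch of that correspondence, assuming an Aesara version whose `TensorType` accepts a static `shape` argument:

```python
import aesara.tensor as at

# A matrix whose first dimension is statically known to be broadcastable.
x = at.TensorType("float64", shape=(1, None))("x")

assert x.type.shape == (1, None)
assert x.type.broadcastable == (True, False)

# The rewrites below therefore replace `r.type.broadcastable[i]` with
# `r.type.shape[i] == 1` and `not r.type.broadcastable[i]` with
# `r.type.shape[i] != 1`.
```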
@@ -364,8 +364,8 @@ class ShapeFeature(Feature):
             else:
                 shape_vars.append(self.unpack(s[i], r))
         assert all(
-            not hasattr(r.type, "broadcastable")
-            or not r.type.broadcastable[i]
+            not hasattr(r.type, "shape")
+            or r.type.shape[i] != 1
             or self.lscalar_one.equals(shape_vars[i])
             or self.lscalar_one.equals(extract_constant(shape_vars[i]))
             for i in range(r.type.ndim)
@@ -447,9 +447,9 @@ class ShapeFeature(Feature):
             merged_shape.append(other_shape[i])
         assert all(
             (
-                not hasattr(r.type, "broadcastable")
-                or not r.type.broadcastable[i]
-                and not other_r.type.broadcastable[i]
+                not hasattr(r.type, "shape")
+                or r.type.shape[i] != 1
+                and other_r.type.shape[i] != 1
             )
             or self.lscalar_one.equals(merged_shape[i])
             or self.lscalar_one.equals(
@@ -474,8 +474,8 @@ class ShapeFeature(Feature):
             else:
                 new_shape.append(s_j)
         assert all(
-            not hasattr(r.type, "broadcastable")
-            or not r.type.broadcastable[idx]
+            not hasattr(r.type, "shape")
+            or r.type.shape[idx] != 1
             or self.lscalar_one.equals(new_shape[idx])
             or self.lscalar_one.equals(extract_constant(new_shape[idx]))
             for idx in range(r.type.ndim)
@@ -781,7 +781,11 @@ def local_reshape_chain(op):
         # We should try to figure out why we lost the information about this
         # constant value... but in the meantime, better not apply this
         # rewrite.
-        if rval.broadcastable == node.outputs[0].broadcastable:
+        if rval.type.ndim == node.outputs[0].type.ndim and all(
+            s1 == s2
+            for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
+            if s1 == 1 or s2 == 1
+        ):
             return [rval]
         else:
             return False
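The replacement predicate here (and in the rewrites below) only compares dimensions where at least one side is statically `1`; dimensions that are unknown (`None`) or statically non-1 on both sides are skipped. A standalone sketch of that logic, using a hypothetical helper name:

```python
# Hypothetical helper mirroring the check used in this commit: two static
# shapes have the same broadcast pattern when every dimension that is 1 on
# either side is 1 on both.
def same_broadcast_pattern(sh1, sh2):
    return len(sh1) == len(sh2) and all(
        s1 == s2 for s1, s2 in zip(sh1, sh2) if s1 == 1 or s2 == 1
    )

assert same_broadcast_pattern((1, None), (1, None))
assert not same_broadcast_pattern((1, None), (None, None))
assert same_broadcast_pattern((None, 3), (None, 5))  # non-1 dims are not compared
```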
@@ -816,7 +820,11 @@ def local_useless_reshape(fgraph, node):
     if (
         inp.type.ndim == 1
         and output.type.ndim == 1
-        and inp.type.broadcastable == output.type.broadcastable
+        and all(
+            s1 == s2
+            for s1, s2 in zip(inp.type.shape, output.type.shape)
+            if s1 == 1 or s2 == 1
+        )
     ):
         return [inp]
@@ -862,7 +870,7 @@ def local_useless_reshape(fgraph, node):
             shape_match[dim] = True
             continue

-        # Match 1 if input.broadcastable[dim] is True
+        # Match 1 if input.type.shape[dim] == 1
         cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
         if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
             shape_match[dim] = True
@@ -931,7 +939,11 @@ def local_reshape_to_dimshuffle(fgraph, node):
     if index != output.type.ndim:
         inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
         copy_stack_trace(output, inner)
-        new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
+        new_node = [
+            DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)(
+                inner
+            )
+        ]
         copy_stack_trace(output, new_node)
         return new_node
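`DimShuffle`'s first argument is the per-dimension broadcastable pattern of its input, so with `broadcastable` gone it is recomputed from the static shape by testing each dimension against `1`. A small illustration with an assumed static shape:

```python
# Assumed static shape of `inner`, for illustration only.
static_shape = (1, None, 3)

# Mirrors `tuple(s == 1 for s in inner.type.shape)` in the diff above.
input_broadcastable = tuple(s == 1 for s in static_shape)
assert input_broadcastable == (True, False, False)
```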
@@ -1096,10 +1108,9 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
         new_order = node.inputs[0].owner.op.new_order
         inp = node.inputs[0].owner.inputs[0]
-        broadcastables = node.inputs[0].broadcastable
         new_order_of_nonbroadcast = []
-        for i, bd in zip(new_order, broadcastables):
-            if not bd:
+        for i, s in zip(new_order, node.inputs[0].type.shape):
+            if s != 1:
                 new_order_of_nonbroadcast.append(i)
         no_change_in_order = all(
             new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
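The rewritten loop keeps only the entries of `new_order` whose corresponding dimension is not statically `1` (this drops the `"x"` entries, which always add length-1 axes), then checks whether the surviving entries are still in ascending order. A worked example with assumed values:

```python
new_order = (1, "x", 0)         # swap two dims and insert a broadcastable axis
static_shape = (None, 1, None)  # assumed static shape of the DimShuffle output

new_order_of_nonbroadcast = [i for i, s in zip(new_order, static_shape) if s != 1]
assert new_order_of_nonbroadcast == [1, 0]

# 1 > 0: the DimShuffle genuinely reorders data, so the rewrite must keep it;
# for new_order = (0, "x", 1) the surviving order would be unchanged.
no_change_in_order = all(
    new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
    for i in range(len(new_order_of_nonbroadcast) - 1)
)
assert no_change_in_order is False
```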
@@ -1123,7 +1134,11 @@ def local_useless_unbroadcast(fgraph, node):
     """
     if isinstance(node.op, Unbroadcast):
         x = node.inputs[0]
-        if x.broadcastable == node.outputs[0].broadcastable:
+        if x.type.ndim == node.outputs[0].type.ndim and all(
+            s1 == s2
+            for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape)
+            if s1 == 1 or s2 == 1
+        ):
             # No broadcastable flag was modified
             # No need to copy over stack trace,
             # because x should already have a stack trace.
@@ -55,13 +55,13 @@ from tests.test_rop import RopLopChecker
 def test_shape_basic():
     s = shape([])
-    assert s.type.broadcastable == (True,)
+    assert s.type.shape == (1,)

     s = shape([10])
-    assert s.type.broadcastable == (True,)
+    assert s.type.shape == (1,)

     s = shape(lscalar())
-    assert s.type.broadcastable == (False,)
+    assert s.type.shape == (0,)

     class MyType(Type):
         def filter(self, *args, **kwargs):
@@ -71,7 +71,7 @@ def test_shape_basic():
         return isinstance(other, MyType) and other.thingy == self.thingy

     s = shape(Variable(MyType(), None))
-    assert s.type.broadcastable == (False,)
+    assert s.type.shape == (None,)

     s = shape(np.array(1))
     assert np.array_equal(eval_outputs([s]), [])
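The updated assertions rely on `Shape` returning a 1-d vector whose static length is the input's `ndim` when that is known: a 0-d scalar yields static shape `(0,)`, while a type with unknown `ndim` (like `MyType` above) yields `(None,)`. A quick check in the same style as the test, assuming a post-commit Aesara:

```python
from aesara.tensor import lscalar, matrix, shape

s = shape(lscalar("n"))
assert s.type.ndim == 1
assert s.type.shape == (0,)  # a 0-d input has an empty shape vector

s = shape(matrix("m"))
assert s.type.shape == (2,)  # a matrix's shape vector has exactly two entries
```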
@@ -119,15 +119,14 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
         b = dmatrix()
         d = dmatrix()

-        # basic to 1 dim(without list)
-        c = reshape(b, as_tensor_variable(6), ndim=1)
-        f = self.function([b], c)
-
         b_val1 = np.asarray([[0, 1, 2], [3, 4, 5]])
         c_val1 = np.asarray([0, 1, 2, 3, 4, 5])
         b_val2 = b_val1.T
         c_val2 = np.asarray([0, 3, 1, 4, 2, 5])

+        # basic to 1 dim(without list)
+        c = reshape(b, as_tensor_variable(6), ndim=1)
+        f = self.function([b], c)
         f_out1 = f(b_val1)
         f_out2 = f(b_val2)
         assert np.array_equal(f_out1, c_val1), (f_out1, c_val1)
@@ -191,10 +190,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
             f(np.asarray([[0, 1, 2], [3, 4, 5]])),
             np.asarray([[[0], [1], [2]], [[3], [4], [5]]]),
         )
-        assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == (
-            False,
-            False,
-            True,
+        assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == (
+            None,
+            None,
+            1,
         )

         # test broadcast flag for constant value of 1 if it cannot be
@@ -205,10 +204,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
             f(np.asarray([[0, 1, 2], [3, 4, 5]])),
             np.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]),
         )
-        assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == (
-            False,
-            False,
-            True,
+        assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == (
+            None,
+            None,
+            1,
         )

     def test_m1(self):