提交 7d72236a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace use of broadcastable with shape in aesara.tensor.shape

上级 46a46af2
...@@ -364,8 +364,8 @@ class ShapeFeature(Feature): ...@@ -364,8 +364,8 @@ class ShapeFeature(Feature):
else: else:
shape_vars.append(self.unpack(s[i], r)) shape_vars.append(self.unpack(s[i], r))
assert all( assert all(
not hasattr(r.type, "broadcastable") not hasattr(r.type, "shape")
or not r.type.broadcastable[i] or r.type.shape[i] != 1
or self.lscalar_one.equals(shape_vars[i]) or self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i])) or self.lscalar_one.equals(extract_constant(shape_vars[i]))
for i in range(r.type.ndim) for i in range(r.type.ndim)
...@@ -447,9 +447,9 @@ class ShapeFeature(Feature): ...@@ -447,9 +447,9 @@ class ShapeFeature(Feature):
merged_shape.append(other_shape[i]) merged_shape.append(other_shape[i])
assert all( assert all(
( (
not hasattr(r.type, "broadcastable") not hasattr(r.type, "shape")
or not r.type.broadcastable[i] or r.type.shape[i] != 1
and not other_r.type.broadcastable[i] and other_r.type.shape[i] != 1
) )
or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals( or self.lscalar_one.equals(
...@@ -474,8 +474,8 @@ class ShapeFeature(Feature): ...@@ -474,8 +474,8 @@ class ShapeFeature(Feature):
else: else:
new_shape.append(s_j) new_shape.append(s_j)
assert all( assert all(
not hasattr(r.type, "broadcastable") not hasattr(r.type, "shape")
or not r.type.broadcastable[idx] or r.type.shape[idx] != 1
or self.lscalar_one.equals(new_shape[idx]) or self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx])) or self.lscalar_one.equals(extract_constant(new_shape[idx]))
for idx in range(r.type.ndim) for idx in range(r.type.ndim)
...@@ -781,7 +781,11 @@ def local_reshape_chain(op): ...@@ -781,7 +781,11 @@ def local_reshape_chain(op):
# We should try to figure out why we lost the information about this # We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this # constant value... but in the meantime, better not apply this
# rewrite. # rewrite.
if rval.broadcastable == node.outputs[0].broadcastable: if rval.type.ndim == node.outputs[0].type.ndim and all(
s1 == s1
for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
if s1 == 1 or s2 == 1
):
return [rval] return [rval]
else: else:
return False return False
...@@ -816,7 +820,11 @@ def local_useless_reshape(fgraph, node): ...@@ -816,7 +820,11 @@ def local_useless_reshape(fgraph, node):
if ( if (
inp.type.ndim == 1 inp.type.ndim == 1
and output.type.ndim == 1 and output.type.ndim == 1
and inp.type.broadcastable == output.type.broadcastable and all(
s1 == s2
for s1, s2 in zip(inp.type.shape, output.type.shape)
if s1 == 1 or s2 == 1
)
): ):
return [inp] return [inp]
...@@ -862,7 +870,7 @@ def local_useless_reshape(fgraph, node): ...@@ -862,7 +870,7 @@ def local_useless_reshape(fgraph, node):
shape_match[dim] = True shape_match[dim] = True
continue continue
# Match 1 if input.broadcastable[dim] is True # Match 1 if input.type.shape[dim] == 1
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
if inp.type.shape[dim] == 1 and cst_outshp_i == 1: if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
shape_match[dim] = True shape_match[dim] = True
...@@ -931,7 +939,11 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -931,7 +939,11 @@ def local_reshape_to_dimshuffle(fgraph, node):
if index != output.type.ndim: if index != output.type.ndim:
inner = op.__class__(len(new_output_shape))(inp, new_output_shape) inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
copy_stack_trace(output, inner) copy_stack_trace(output, inner)
new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] new_node = [
DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)(
inner
)
]
copy_stack_trace(output, new_node) copy_stack_trace(output, new_node)
return new_node return new_node
...@@ -1096,10 +1108,9 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): ...@@ -1096,10 +1108,9 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
new_order = node.inputs[0].owner.op.new_order new_order = node.inputs[0].owner.op.new_order
inp = node.inputs[0].owner.inputs[0] inp = node.inputs[0].owner.inputs[0]
broadcastables = node.inputs[0].broadcastable
new_order_of_nonbroadcast = [] new_order_of_nonbroadcast = []
for i, bd in zip(new_order, broadcastables): for i, s in zip(new_order, node.inputs[0].type.shape):
if not bd: if s != 1:
new_order_of_nonbroadcast.append(i) new_order_of_nonbroadcast.append(i)
no_change_in_order = all( no_change_in_order = all(
new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
...@@ -1123,7 +1134,11 @@ def local_useless_unbroadcast(fgraph, node): ...@@ -1123,7 +1134,11 @@ def local_useless_unbroadcast(fgraph, node):
""" """
if isinstance(node.op, Unbroadcast): if isinstance(node.op, Unbroadcast):
x = node.inputs[0] x = node.inputs[0]
if x.broadcastable == node.outputs[0].broadcastable: if x.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape)
if s1 == 1 or s2 == 1
):
# No broadcastable flag was modified # No broadcastable flag was modified
# No need to copy over stack trace, # No need to copy over stack trace,
# because x should already have a stack trace. # because x should already have a stack trace.
......
...@@ -55,13 +55,13 @@ from tests.test_rop import RopLopChecker ...@@ -55,13 +55,13 @@ from tests.test_rop import RopLopChecker
def test_shape_basic(): def test_shape_basic():
s = shape([]) s = shape([])
assert s.type.broadcastable == (True,) assert s.type.shape == (1,)
s = shape([10]) s = shape([10])
assert s.type.broadcastable == (True,) assert s.type.shape == (1,)
s = shape(lscalar()) s = shape(lscalar())
assert s.type.broadcastable == (False,) assert s.type.shape == (0,)
class MyType(Type): class MyType(Type):
def filter(self, *args, **kwargs): def filter(self, *args, **kwargs):
...@@ -71,7 +71,7 @@ def test_shape_basic(): ...@@ -71,7 +71,7 @@ def test_shape_basic():
return isinstance(other, MyType) and other.thingy == self.thingy return isinstance(other, MyType) and other.thingy == self.thingy
s = shape(Variable(MyType(), None)) s = shape(Variable(MyType(), None))
assert s.type.broadcastable == (False,) assert s.type.shape == (None,)
s = shape(np.array(1)) s = shape(np.array(1))
assert np.array_equal(eval_outputs([s]), []) assert np.array_equal(eval_outputs([s]), [])
...@@ -119,15 +119,14 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -119,15 +119,14 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
b = dmatrix() b = dmatrix()
d = dmatrix() d = dmatrix()
# basic to 1 dim(without list)
c = reshape(b, as_tensor_variable(6), ndim=1)
f = self.function([b], c)
b_val1 = np.asarray([[0, 1, 2], [3, 4, 5]]) b_val1 = np.asarray([[0, 1, 2], [3, 4, 5]])
c_val1 = np.asarray([0, 1, 2, 3, 4, 5]) c_val1 = np.asarray([0, 1, 2, 3, 4, 5])
b_val2 = b_val1.T b_val2 = b_val1.T
c_val2 = np.asarray([0, 3, 1, 4, 2, 5]) c_val2 = np.asarray([0, 3, 1, 4, 2, 5])
# basic to 1 dim(without list)
c = reshape(b, as_tensor_variable(6), ndim=1)
f = self.function([b], c)
f_out1 = f(b_val1) f_out1 = f(b_val1)
f_out2 = f(b_val2) f_out2 = f(b_val2)
assert np.array_equal(f_out1, c_val1), (f_out1, c_val1) assert np.array_equal(f_out1, c_val1), (f_out1, c_val1)
...@@ -191,10 +190,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -191,10 +190,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
f(np.asarray([[0, 1, 2], [3, 4, 5]])), f(np.asarray([[0, 1, 2], [3, 4, 5]])),
np.asarray([[[0], [1], [2]], [[3], [4], [5]]]), np.asarray([[[0], [1], [2]], [[3], [4], [5]]]),
) )
assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == ( assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == (
False, None,
False, None,
True, 1,
) )
# test broadcast flag for constant value of 1 if it cannot be # test broadcast flag for constant value of 1 if it cannot be
...@@ -205,10 +204,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -205,10 +204,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
f(np.asarray([[0, 1, 2], [3, 4, 5]])), f(np.asarray([[0, 1, 2], [3, 4, 5]])),
np.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]), np.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]),
) )
assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == ( assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == (
False, None,
False, None,
True, 1,
) )
def test_m1(self): def test_m1(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论