提交 507624c4 authored 作者: Frederic's avatar Frederic

Make the assert compare with an equivalent TensorConstant and bypass a few ops.

上级 61b8fea3
...@@ -821,7 +821,9 @@ class ShapeFeature(object): ...@@ -821,7 +821,9 @@ class ShapeFeature(object):
else: else:
shape_vars.append(self.unpack(s[i])) shape_vars.append(self.unpack(s[i]))
assert all([not r.type.broadcastable[i] or assert all([not r.type.broadcastable[i] or
shape_vars[i] == self.lscalar_one self.lscalar_one.equals(shape_vars[i]) or
self.lscalar_one.equals(
T.extract_constant(shape_vars[i]))
for i in range(r.ndim)]) for i in range(r.ndim)])
self.shape_of[r] = tuple(shape_vars) self.shape_of[r] = tuple(shape_vars)
for sv in shape_vars: for sv in shape_vars:
...@@ -866,7 +868,9 @@ class ShapeFeature(object): ...@@ -866,7 +868,9 @@ class ShapeFeature(object):
merged_shape.append(other_shape[i]) merged_shape.append(other_shape[i])
assert all([(not r.type.broadcastable[i] and assert all([(not r.type.broadcastable[i] and
not other_r.type.broadcastable[i]) or not other_r.type.broadcastable[i]) or
merged_shape[i] == self.lscalar_one self.lscalar_one.equals(merged_shape[i]) or
self.lscalar_one.equals(
T.extract_constant(merged_shape[i]))
for i in range(r.ndim)]) for i in range(r.ndim)])
self.shape_of[r] = tuple(merged_shape) self.shape_of[r] = tuple(merged_shape)
for sv in self.shape_of[r]: for sv in self.shape_of[r]:
...@@ -885,7 +889,8 @@ class ShapeFeature(object): ...@@ -885,7 +889,8 @@ class ShapeFeature(object):
else: else:
new_shape.append(s_j) new_shape.append(s_j)
assert all([not r.type.broadcastable[i] or assert all([not r.type.broadcastable[i] or
new_shape[i] == self.lscalar_one self.lscalar_one.equals(new_shape[i]) or
self.lscalar_one.equals(T.extract_constant(new_shape[i]))
for i in range(r.ndim)]) for i in range(r.ndim)])
self.shape_of[r] = tuple(new_shape) self.shape_of[r] = tuple(new_shape)
for sv in self.shape_of[r]: for sv in self.shape_of[r]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论