提交 507624c4 authored 作者: Frederic's avatar Frederic

Make the assert compare with equivalent tensorconstant and bypass a few op.

上级 61b8fea3
......@@ -821,7 +821,9 @@ class ShapeFeature(object):
else:
shape_vars.append(self.unpack(s[i]))
assert all([not r.type.broadcastable[i] or
shape_vars[i] == self.lscalar_one
self.lscalar_one.equals(shape_vars[i]) or
self.lscalar_one.equals(
T.extract_constant(shape_vars[i]))
for i in range(r.ndim)])
self.shape_of[r] = tuple(shape_vars)
for sv in shape_vars:
......@@ -866,7 +868,9 @@ class ShapeFeature(object):
merged_shape.append(other_shape[i])
assert all([(not r.type.broadcastable[i] and
not other_r.type.broadcastable[i]) or
merged_shape[i] == self.lscalar_one
self.lscalar_one.equals(merged_shape[i]) or
self.lscalar_one.equals(
T.extract_constant(merged_shape[i]))
for i in range(r.ndim)])
self.shape_of[r] = tuple(merged_shape)
for sv in self.shape_of[r]:
......@@ -885,7 +889,8 @@ class ShapeFeature(object):
else:
new_shape.append(s_j)
assert all([not r.type.broadcastable[i] or
new_shape[i] == self.lscalar_one
self.lscalar_one.equals(new_shape[i]) or
self.lscalar_one.equals(T.extract_constant(new_shape[i]))
for i in range(r.ndim)])
self.shape_of[r] = tuple(new_shape)
for sv in self.shape_of[r]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论