提交 09a4d5ed authored 作者: Frederic Bastien's avatar Frederic Bastien

Add ShapeFeature.same_shape, that try to compare 2 var shapes.

上级 69456ab6
......@@ -1145,6 +1145,40 @@ class ShapeFeature(object):
self.set_shape_i(v, ii, new_r)
self.shape_of_reverse_index[r] = set()
def same_shape(self, x, y):
"""Return True if we are able to assert that x and y have the
same shape
"""
sx = self.shape_of[x]
sy = self.shape_of[y]
if sx is None or sy is None:
return False
assert len(sx) == len(sy)
for dx, dy in zip(sx, sy):
if dx is dy:
continue
# Need to try to find that they are the same shape. We
# need to compare the full graph. It could be slow. So I
# just implement for now the case of Shape_i.
if not dx.owner or not dy.owner:
return False
if (not isinstance(dx.owner.op, Shape_i) or
not isinstance(dy.owner.op, Shape_i)):
return False
opx = dx.owner.op
opy = dy.owner.op
if not (opx.i == opy.i):
return False
# FB I'm not sure is this handle correctly constants.
if dx.owner.inputs[0] == dy.owner.inputs[0]:
return True
# To be sure to cover all case, call equal_computation.
# Can't use theano.gof.graph.is_same_graph(dx, dy)
# As it currently expect that dx and dy aren't in a FunctionGraph
from theano.scan_module.scan_utils import equal_computations
return equal_computations([dx], [dy])
class ShapeOptimizer(Optimizer):
"""Optimizer that serves to add ShapeFeature as an fgraph feature.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论