提交 512b3eac authored 作者: Frederic's avatar Frederic

Better implementation

上级 6a732dd7
......@@ -1180,7 +1180,7 @@ class ShapeFeature(object):
same shape.
dim_x and dim_y are optional. If used, they should be an index
to compare only 1 shape of x or y.
to compare only 1 dimension of x and y.
"""
sx = self.shape_of[x]
......@@ -1193,6 +1193,9 @@ class ShapeFeature(object):
sy = [sy[dim_y]]
assert len(sx) == len(sy)
# We look on each dimensions we want to compare.
# If any of them can't be asserted to be equal, return False.
# Otherwise, we return True at the end.
for dx, dy in zip(sx, sy):
if dx is dy:
continue
......@@ -1210,12 +1213,14 @@ class ShapeFeature(object):
return False
# FB I'm not sure is this handle correctly constants.
if dx.owner.inputs[0] == dy.owner.inputs[0]:
return True
continue
# To be sure to cover all case, call equal_computation.
# Can't use theano.gof.graph.is_same_graph(dx, dy)
# As it currently expect that dx and dy aren't in a FunctionGraph
from theano.scan_module.scan_utils import equal_computations
return equal_computations([dx], [dy])
if not equal_computations([dx], [dy]):
return False
return True
class ShapeOptimizer(Optimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论