提交 9bdef90d authored 作者: abergeron's avatar abergeron

Merge pull request #3385 from lamblin/fix_shape_i

Use shape_of instead of explicitly inserting Shape_i
...@@ -924,7 +924,7 @@ class ShapeFeature(object): ...@@ -924,7 +924,7 @@ class ShapeFeature(object):
# worst case, we loop over shape_of and replace things # worst case, we loop over shape_of and replace things
raise NotImplementedError(s_i) raise NotImplementedError(s_i)
# s_i is x.shape[i], we change it to Shape_i. # s_i is x.shape[i] for some x, we change it to shape_of[x][i]
if (s_i.owner and if (s_i.owner and
isinstance(s_i.owner.op, Subtensor) and isinstance(s_i.owner.op, Subtensor) and
s_i.owner.inputs[0].owner and s_i.owner.inputs[0].owner and
...@@ -940,9 +940,13 @@ class ShapeFeature(object): ...@@ -940,9 +940,13 @@ class ShapeFeature(object):
idx = idx[0] idx = idx[0]
try: try:
i = get_scalar_constant_value(idx) i = get_scalar_constant_value(idx)
s_i = Shape_i(i)(s_i.owner.inputs[0].owner.inputs[0])
except NotScalarConstantError: except NotScalarConstantError:
pass pass
else:
# Executed only if no exception was raised
x = s_i.owner.inputs[0].owner.inputs[0]
# x should already have been imported, and should be in shape_of.
s_i = self.shape_of[x][i]
if s_i.type.dtype[:3] in ('int', 'uint'): if s_i.type.dtype[:3] in ('int', 'uint'):
if getattr(s_i.type, 'ndim', 0): if getattr(s_i.type, 'ndim', 0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论