提交 69456ab6 authored 作者: Frederic Bastien's avatar Frederic Bastien

In the ShapeFeature, convert x.shape[1] to Shape_i. This will help compare shapes.

上级 c644c16b
......@@ -799,7 +799,21 @@ class ShapeFeature(object):
#
# worst case, we loop over shape_of and replace things
raise NotImplementedError(s_i)
elif s_i.type.dtype[:3] in ('int', 'uint'):
# s_i is x.shape[i], we change it to Shape_i.
if (s_i.owner and
isinstance(s_i.owner.op, Subtensor) and
s_i.owner.inputs[0].owner and
isinstance(s_i.owner.inputs[0].owner.op, T.Shape)):
assert s_i.ndim == 0
assert len(s_i.owner.inputs) == 2
try:
i = get_scalar_constant_value(s_i.owner.inputs[1])
s_i = Shape_i(i)(s_i.owner.inputs[0].owner.inputs[0])
except NotScalarConstantError:
pass
if s_i.type.dtype[:3] in ('int', 'uint'):
if getattr(s_i.type, 'ndim', 0):
raise TypeError('Shape element must be scalar', s_i)
return s_i
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论