Fixed assert in local_subtensor_make_vector. ShapeFeature's unpack now deals…

Fixed assert in local_subtensor_make_vector. ShapeFeature's unpack now deals with numpy.integer as well.
上级 cee214f2
...@@ -375,7 +375,7 @@ class ShapeFeature(object): ...@@ -375,7 +375,7 @@ class ShapeFeature(object):
if s_i == 1: if s_i == 1:
# don't make the optimizer merge a zillion ones together # don't make the optimizer merge a zillion ones together
return self.lscalar_one return self.lscalar_one
if type(s_i) is int: if type(s_i) is int or isinstance(s_i, numpy.integer):
# this shape is a constant # this shape is a constant
assert s_i >= 0 assert s_i >= 0
return T.constant(s_i, dtype='int64') return T.constant(s_i, dtype='int64')
...@@ -573,7 +573,7 @@ def local_subtensor_make_vector(node): ...@@ -573,7 +573,7 @@ def local_subtensor_make_vector(node):
# The idx is a Scalar, ie a Type. This means the actual index # The idx is a Scalar, ie a Type. This means the actual index
# is contained in node.inputs[1] # is contained in node.inputs[1]
old_idx, idx = idx, node.inputs[1] old_idx, idx = idx, node.inputs[1]
assert isinstance(idx, old_idx) assert idx.type == old_idx
if isinstance(idx, (int, numpy.integer)): if isinstance(idx, (int, numpy.integer)):
return [x.owner.inputs[idx]] return [x.owner.inputs[idx]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论