提交 f0ac9811 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix a test error as Subtensor index as not always int64 but the shape feature give so.

Small refactoring at the same time.
上级 6f6981e9
...@@ -1266,7 +1266,10 @@ def local_useless_subtensor(node): ...@@ -1266,7 +1266,10 @@ def local_useless_subtensor(node):
except TypeError: except TypeError:
pass pass
if isinstance(idx.stop, theano.scalar.Scalar): if isinstance(idx.stop, int):
if idx.stop < length_pos_data:
return False
elif isinstance(idx.stop, theano.scalar.Scalar):
length_pos_shape_i = node.inputs[node_input_idx] length_pos_shape_i = node.inputs[node_input_idx]
# length_pos is a tensor variable, but length_pos_shape_i # length_pos is a tensor variable, but length_pos_shape_i
# is a scalar variable. We try to see if they represent # is a scalar variable. We try to see if they represent
...@@ -1283,17 +1286,17 @@ def local_useless_subtensor(node): ...@@ -1283,17 +1286,17 @@ def local_useless_subtensor(node):
# We did not find underlying variables of the same type # We did not find underlying variables of the same type
return False return False
assert length_pos_shape_i.type == length_pos.type # The type can be different: int32 vs int64. length_pos
assert length_pos_shape_i.type.dtype == idx.stop.dtype # should always be int64 as that is what the shape
# tracker keep. Subtensor accept any scalar int{8,16,32,64}
# as index type.
assert str(length_pos.type.dtype) == "int64"
assert str(length_pos_shape_i.type.dtype) in ["int8", "int16",
"int32", "int64"]
# We already know that start and step are not variables # We already know that start and step are not variables
# and so they don't appear in the input of the node # and so they don't appear in the input of the node
node_input_idx += 1 node_input_idx += 1
if isinstance(idx.stop, int):
if idx.stop < length_pos_data:
return False
elif isinstance(idx.stop, theano.scalar.Scalar):
# length_pos_shape_i cannot be None # length_pos_shape_i cannot be None
if length_pos_shape_i != length_pos: if length_pos_shape_i != length_pos:
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论