提交 d6d8c839 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More cases in local_subtensor_make_vector

上级 149c9b53
...@@ -567,15 +567,15 @@ def local_subtensor_make_vector(node): ...@@ -567,15 +567,15 @@ def local_subtensor_make_vector(node):
except: except:
#'how can you have multiple indexes into a shape?' #'how can you have multiple indexes into a shape?'
raise raise
if isinstance(idx, int): if isinstance(idx, (int, numpy.integer)):
return [x.owner.inputs[idx]] return [x.owner.inputs[idx]]
elif isinstance(idx, T.TensorVariable): elif isinstance(idx, (T.TensorVariable, T.TensorConstant)):
# if it is a constant we can do something with it # if it is a constant we can do something with it
try: try:
v = get_constant_value(idx) v = get_constant_value(idx)
return [x.owner.inputs[v]] return [x.owner.inputs[v]]
except: except:
pass pass
else: else:
# it is a slice of ints and/or Variables # it is a slice of ints and/or Variables
#TODO: check subtensor to see if it can contain constant variables, #TODO: check subtensor to see if it can contain constant variables,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论