提交 9ca3be9c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Convert python literals to theano constants before handing them off to Subtensor.

上级 5420caac
......@@ -4,7 +4,8 @@ import numpy
import theano
from theano.compat import all, PY3
from theano.scalar import ComplexError, IntegerDivisionError
from theano.scalar import (ComplexError, IntegerDivisionError,
ScalarConstant, int64)
from theano.gof import Constant, Variable
from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray
......@@ -348,6 +349,19 @@ class _tensor_py_operators:
def __getitem__(self, args):
if not isinstance(args, tuple):
args = args,
# Convert python literals to theano constants
def conv(a):
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start),
conv(a.stop),
conv(a.step))
elif isinstance(a, (int, long, numpy.integer)):
return ScalarConstant(int64, a)
else:
return a
args = tuple(map(conv, args))
# Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论