提交 9ca3be9c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Convert python literals to theano constants before handing them off to Subtensor.

上级 5420caac
...@@ -4,7 +4,8 @@ import numpy ...@@ -4,7 +4,8 @@ import numpy
import theano import theano
from theano.compat import all, PY3 from theano.compat import all, PY3
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import (ComplexError, IntegerDivisionError,
ScalarConstant, int64)
from theano.gof import Constant, Variable from theano.gof import Constant, Variable
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
...@@ -348,6 +349,19 @@ class _tensor_py_operators: ...@@ -348,6 +349,19 @@ class _tensor_py_operators:
def __getitem__(self, args): def __getitem__(self, args):
if not isinstance(args, tuple): if not isinstance(args, tuple):
args = args, args = args,
# Convert python literals to theano constants
def conv(a):
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start),
conv(a.stop),
conv(a.step))
elif isinstance(a, (int, long, numpy.integer)):
return ScalarConstant(int64, a)
else:
return a
args = tuple(map(conv, args))
# Determine if advanced indexing is needed or not # Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds, # The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论