提交 dfbc1334 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add new subtensor helper make_constant and use it to fix a bug in local_subtensor_merge.

上级 85209fbb
......@@ -25,7 +25,8 @@ from theano.gof.utils import MethodNotDefined
from theano.configparser import config
from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, AdvancedIncSubtensor1)
Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1)
from theano import scalar
from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file
......@@ -1955,6 +1956,7 @@ def local_subtensor_merge(node):
else:
merged_slices += slices1[pos_1:]
merged_slices = make_constant(merged_slices)
subtens = Subtensor(merged_slices)
sl_ins = Subtensor.collapse(
merged_slices,
......
......@@ -47,6 +47,23 @@ class AdvancedIndexingError(TypeError):
# Helpful functions to deal with Subtensor and IncSubtensor
##########
def make_constant(args):
"""
Convert python litterals to theano constants in subtensor arguments.
"""
def conv(a):
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start),
conv(a.stop),
conv(a.step))
elif isinstance(a, (int, long, numpy.integer)):
return scal.ScalarConstant(scal.int64, a)
else:
return a
return tuple(map(conv, args))
def get_idx_list(inputs, idx_list):
'''
Given a list of inputs to the subtensor and its idx_list reorders
......
......@@ -4,8 +4,7 @@ import numpy
import theano
from theano.compat import all, PY3
from theano.scalar import (ComplexError, IntegerDivisionError,
ScalarConstant, int64)
from theano.scalar import ComplexError, IntegerDivisionError
from theano.gof import Constant, Variable
from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray
......@@ -350,18 +349,7 @@ class _tensor_py_operators:
if not isinstance(args, tuple):
args = args,
# Convert python literals to theano constants
def conv(a):
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start),
conv(a.stop),
conv(a.step))
elif isinstance(a, (int, long, numpy.integer)):
return ScalarConstant(int64, a)
else:
return a
args = tuple(map(conv, args))
args = theano.tensor.subtensor.make_constant(args)
# Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论