提交 99a994cb authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1757 from abergeron/subtensor_fallout

Add new subtensor helper make_constant and use it to fix a bug in local_subtensor_merge.
...@@ -26,7 +26,8 @@ from theano.gof.utils import MethodNotDefined ...@@ -26,7 +26,8 @@ from theano.gof.utils import MethodNotDefined
from theano.configparser import config from theano.configparser import config
from theano.tensor.elemwise import Elemwise, DimShuffle from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, AdvancedIncSubtensor1) Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1)
from theano import scalar from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
...@@ -1956,6 +1957,7 @@ def local_subtensor_merge(node): ...@@ -1956,6 +1957,7 @@ def local_subtensor_merge(node):
else: else:
merged_slices += slices1[pos_1:] merged_slices += slices1[pos_1:]
merged_slices = make_constant(merged_slices)
subtens = Subtensor(merged_slices) subtens = Subtensor(merged_slices)
sl_ins = Subtensor.collapse( sl_ins = Subtensor.collapse(
merged_slices, merged_slices,
......
...@@ -47,6 +47,23 @@ class AdvancedIndexingError(TypeError): ...@@ -47,6 +47,23 @@ class AdvancedIndexingError(TypeError):
# Helpful functions to deal with Subtensor and IncSubtensor # Helpful functions to deal with Subtensor and IncSubtensor
########## ##########
def make_constant(args):
"""
Convert python litterals to theano constants in subtensor arguments.
"""
def conv(a):
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start),
conv(a.stop),
conv(a.step))
elif isinstance(a, (int, long, numpy.integer)):
return scal.ScalarConstant(scal.int64, a)
else:
return a
return tuple(map(conv, args))
def get_idx_list(inputs, idx_list): def get_idx_list(inputs, idx_list):
''' '''
Given a list of inputs to the subtensor and its idx_list reorders Given a list of inputs to the subtensor and its idx_list reorders
......
...@@ -4,8 +4,7 @@ import numpy ...@@ -4,8 +4,7 @@ import numpy
import theano import theano
from theano.compat import all, PY3 from theano.compat import all, PY3
from theano.scalar import (ComplexError, IntegerDivisionError, from theano.scalar import ComplexError, IntegerDivisionError
ScalarConstant, int64)
from theano.gof import Constant, Variable from theano.gof import Constant, Variable
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
...@@ -350,18 +349,7 @@ class _tensor_py_operators: ...@@ -350,18 +349,7 @@ class _tensor_py_operators:
if not isinstance(args, tuple): if not isinstance(args, tuple):
args = args, args = args,
# Convert python literals to theano constants # Convert python literals to theano constants
def conv(a): args = theano.tensor.subtensor.make_constant(args)
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start),
conv(a.stop),
conv(a.step))
elif isinstance(a, (int, long, numpy.integer)):
return ScalarConstant(int64, a)
else:
return a
args = tuple(map(conv, args))
# Determine if advanced indexing is needed or not # Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds, # The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论