提交 7857774b authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Rename and refactor subtensor.make_constant

上级 25defabd
...@@ -50,7 +50,7 @@ from theano.tensor.subtensor import ( ...@@ -50,7 +50,7 @@ from theano.tensor.subtensor import (
get_canonical_form_slice, get_canonical_form_slice,
Subtensor, Subtensor,
IncSubtensor, IncSubtensor,
make_constant, as_index_constant,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
...@@ -3388,7 +3388,7 @@ def local_subtensor_merge(node): ...@@ -3388,7 +3388,7 @@ def local_subtensor_merge(node):
else: else:
merged_slices += slices1[pos_1:] merged_slices += slices1[pos_1:]
merged_slices = make_constant(merged_slices) merged_slices = tuple(as_index_constant(s) for s in merged_slices)
subtens = Subtensor(merged_slices) subtens = Subtensor(merged_slices)
sl_ins = Subtensor.collapse( sl_ins = Subtensor.collapse(
......
...@@ -51,21 +51,25 @@ class AdvancedBooleanIndexingError(TypeError): ...@@ -51,21 +51,25 @@ class AdvancedBooleanIndexingError(TypeError):
pass pass
def as_index_constant(a):
    """Convert Python literals to Theano constants--when possible--in `Subtensor` arguments.

    This will leave `Variable`s untouched.

    Parameters
    ----------
    a : object
        An index component: ``None``, a Python/NumPy integer, a `slice`
        (whose ``start``/``stop``/``step`` are converted recursively), a
        Theano `Variable` (returned unchanged), or anything else accepted
        by ``theano.tensor.as_tensor``.

    Returns
    -------
    object
        ``a`` itself for ``None`` and `Variable` inputs; a
        ``ScalarConstant`` for integer literals; a new `slice` with
        converted components; otherwise the result of
        ``theano.tensor.as_tensor(a)``.
    """
    if a is None:
        # `None` is a valid index component (e.g. open-ended slices);
        # it must pass through untouched.
        return a
    elif isinstance(a, slice):
        # Convert each slice component recursively so nested literals
        # become constants too.
        return slice(
            as_index_constant(a.start),
            as_index_constant(a.stop),
            as_index_constant(a.step),
        )
    elif isinstance(a, (integer_types, np.integer)):
        # Integer literals become int64 scalar constants.
        return scal.ScalarConstant(scal.int64, a)
    elif not isinstance(a, theano.tensor.Variable):
        # Non-Variable, non-literal values (e.g. lists, ndarrays) are
        # wrapped as tensor constants.
        return theano.tensor.as_tensor(a)
    else:
        # Already a symbolic Variable: leave untouched.
        return a
def get_idx_list(inputs, idx_list, get_count=False): def get_idx_list(inputs, idx_list, get_count=False):
...@@ -277,7 +281,9 @@ def range_len(slc): ...@@ -277,7 +281,9 @@ def range_len(slc):
""" """
from theano.tensor import switch, and_, lt, gt from theano.tensor import switch, and_, lt, gt
start, stop, step = make_constant([slc.start, slc.stop, slc.step]) start, stop, step = tuple(
as_index_constant(a) for a in [slc.start, slc.stop, slc.step]
)
return switch( return switch(
and_(gt(step, 0), lt(start, stop)), and_(gt(step, 0), lt(start, stop)),
1 + (stop - 1 - start) // step, 1 + (stop - 1 - start) // step,
......
...@@ -545,14 +545,16 @@ class _tensor_py_operators(object): ...@@ -545,14 +545,16 @@ class _tensor_py_operators(object):
# Force input to be int64 datatype if input is an empty list or tuple # Force input to be int64 datatype if input is an empty list or tuple
# Else leave it as is if it is a real number # Else leave it as is if it is a real number
# Convert python literals to theano constants
args = tuple( args = tuple(
[ [
np.array(inp, dtype=np.int64) if (is_empty_array(inp)) else inp theano.tensor.subtensor.as_index_constant(
np.array(inp, dtype=np.int64) if is_empty_array(inp) else inp
)
for inp in args for inp in args
] ]
) )
# Convert python literals to theano constants
args = theano.tensor.subtensor.make_constant(args)
# Determine if advanced indexing is needed or not # Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds, # The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论