提交 7857774b authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Rename and refactor subtensor.make_constant

上级 25defabd
......@@ -50,7 +50,7 @@ from theano.tensor.subtensor import (
get_canonical_form_slice,
Subtensor,
IncSubtensor,
make_constant,
as_index_constant,
AdvancedIncSubtensor1,
AdvancedIncSubtensor,
AdvancedSubtensor1,
......@@ -3388,7 +3388,7 @@ def local_subtensor_merge(node):
else:
merged_slices += slices1[pos_1:]
merged_slices = make_constant(merged_slices)
merged_slices = tuple(as_index_constant(s) for s in merged_slices)
subtens = Subtensor(merged_slices)
sl_ins = Subtensor.collapse(
......
......@@ -51,21 +51,25 @@ class AdvancedBooleanIndexingError(TypeError):
pass
def make_constant(args):
"""Convert Python literals to Theano constants in Subtensor arguments."""
def conv(a):
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start), conv(a.stop), conv(a.step))
elif isinstance(a, (integer_types, np.integer)):
return scal.ScalarConstant(scal.int64, a)
else:
# Use `tensor.scalar_from_tensor`?
return a
def as_index_constant(a):
"""Convert Python literals to Theano constants--when possible--in Subtensor arguments.
return tuple(map(conv, args))
This will leave `Variable`s untouched.
"""
if a is None:
return a
elif isinstance(a, slice):
return slice(
as_index_constant(a.start),
as_index_constant(a.stop),
as_index_constant(a.step),
)
elif isinstance(a, (integer_types, np.integer)):
return scal.ScalarConstant(scal.int64, a)
elif not isinstance(a, theano.tensor.Variable):
return theano.tensor.as_tensor(a)
else:
return a
def get_idx_list(inputs, idx_list, get_count=False):
......@@ -277,7 +281,9 @@ def range_len(slc):
"""
from theano.tensor import switch, and_, lt, gt
start, stop, step = make_constant([slc.start, slc.stop, slc.step])
start, stop, step = tuple(
as_index_constant(a) for a in [slc.start, slc.stop, slc.step]
)
return switch(
and_(gt(step, 0), lt(start, stop)),
1 + (stop - 1 - start) // step,
......
......@@ -545,14 +545,16 @@ class _tensor_py_operators(object):
# Force input to be int64 datatype if input is an empty list or tuple
# Else leave it as is if it is a real number
# Convert python literals to theano constants
args = tuple(
[
np.array(inp, dtype=np.int64) if (is_empty_array(inp)) else inp
theano.tensor.subtensor.as_index_constant(
np.array(inp, dtype=np.int64) if is_empty_array(inp) else inp
)
for inp in args
]
)
# Convert python literals to theano constants
args = theano.tensor.subtensor.make_constant(args)
# Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论