Commit 51de50be authored by Brandon T. Willard, committed by Brandon T. Willard

Extract general utility methods from Subtensor class

Parent: df2a45d8
@@ -50,6 +50,7 @@ from aesara.tensor.subtensor import (
     Subtensor,
     get_canonical_form_slice,
     get_idx_list,
+    get_slice_elements,
     set_subtensor,
 )
 from aesara.tensor.var import TensorConstant, get_unique_value
@@ -1548,7 +1549,7 @@ def save_mem_new_scan(fgraph, node):
                 subtens = Subtensor(nw_slice)
                 # slice inputs
-                sl_ins = Subtensor.collapse(
+                sl_ins = get_slice_elements(
                     nw_slice, lambda entry: isinstance(entry, Variable)
                 )
                 new_o = subtens(new_outs[nw_pos], *sl_ins)
@@ -1598,7 +1599,7 @@ def save_mem_new_scan(fgraph, node):
                 nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                 subtens = Subtensor(nw_slice)
-                sl_ins = Subtensor.collapse(
+                sl_ins = get_slice_elements(
                     nw_slice, lambda entry: isinstance(entry, Variable)
                 )
                 new_o = subtens(new_outs[nw_pos], *sl_ins)
......
@@ -417,7 +417,9 @@ def get_scalar_constant_value(
             and v.ndim == 0
         ):
             if isinstance(v.owner.inputs[0], TensorConstant):
-                cdata = tuple(v.owner.op.get_constant_idx(v.owner.inputs))
+                from aesara.tensor.subtensor import get_constant_idx
+
+                cdata = tuple(get_constant_idx(v.owner.op.idx_list, v.owner.inputs))
                 try:
                     return v.owner.inputs[0].data.__getitem__(cdata).copy()
                 except IndexError:
......
@@ -58,7 +58,12 @@ from aesara.tensor.math import sum as aet_sum
 from aesara.tensor.math import tanh, tensordot, true_div
 from aesara.tensor.nnet.blocksparse import sparse_block_dot
 from aesara.tensor.shape import shape, shape_padleft
-from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
+from aesara.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    Subtensor,
+    get_constant_idx,
+)
 from aesara.tensor.type import (
     TensorType,
     discrete_dtypes,
@@ -1736,8 +1741,8 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels):
         # ShapeOptimizer, but we keep it if ShapeOptimizer is not present
         if isinstance(stop.owner.op, Subtensor):
             shape_subtensor = stop.owner
-            if shape_subtensor.op.get_constant_idx(
-                shape_subtensor.inputs, allow_partial=True
+            if get_constant_idx(
+                shape_subtensor.op.idx_list, shape_subtensor.inputs, allow_partial=True
             ) == [0]:
                 shape_var = shape_subtensor.inputs[0]
                 if shape_var.owner and shape_var.owner.op == shape:
......
This diff is collapsed.
@@ -67,7 +67,9 @@ from aesara.tensor.subtensor import (
     as_index_constant,
     as_index_literal,
     get_canonical_form_slice,
+    get_constant_idx,
     get_idx_list,
+    get_slice_elements,
     inc_subtensor,
 )
 from aesara.tensor.type import TensorType
@@ -347,7 +349,7 @@ def local_useless_slice(fgraph, node):
         # check if we removed something
         if last_slice < len(slices):
             subtens = Subtensor(slices[:last_slice])
-            sl_ins = Subtensor.collapse(
+            sl_ins = get_slice_elements(
                 slices[:last_slice], lambda x: isinstance(x, Variable)
             )
             out = subtens(node.inputs[0], *sl_ins)
@@ -518,7 +520,7 @@ def local_subtensor_merge(fgraph, node):
             merged_slices = tuple(as_index_constant(s) for s in merged_slices)
             subtens = Subtensor(merged_slices)
-            sl_ins = Subtensor.collapse(
+            sl_ins = get_slice_elements(
                 merged_slices, lambda x: isinstance(x, Variable)
             )
             # Do not call make_node for test_value
@@ -766,7 +768,9 @@ def local_subtensor_make_vector(fgraph, node):
         # The index is a slice. If it's a constant slice, we can perform the
         # index operation here.
         try:
-            const_slice = node.op.get_constant_idx(node.inputs, allow_partial=False)[0]
+            const_slice = get_constant_idx(
+                node.op.idx_list, node.inputs, allow_partial=False
+            )[0]
             ret = make_vector_op(*x.owner.inputs[const_slice])
             copy_stack_trace(node.outputs, ret)
             ret = patternbroadcast(ret, node.outputs[0].broadcastable)
@@ -896,8 +900,11 @@ def local_useless_subtensor(fgraph, node):
     shape_of = fgraph.shape_feature.shape_of
     if isinstance(node.op, Subtensor):
-        cdata = node.op.get_constant_idx(
-            node.inputs, allow_partial=True, only_process_constants=True
+        cdata = get_constant_idx(
+            node.op.idx_list,
+            node.inputs,
+            allow_partial=True,
+            only_process_constants=True,
         )
         for pos, idx in enumerate(cdata):
             if not isinstance(idx, slice):
......
@@ -526,8 +526,8 @@ class _tensor_py_operators:
             )
         # Determine if advanced indexing is needed or not. The logic is
-        # already in `Subtensor.convert`: if it succeeds, standard indexing is
-        # used; if it fails with AdvancedIndexingError, advanced indexing is
+        # already in `index_vars_to_types`: if it succeeds, standard indexing is
+        # used; if it fails with `AdvancedIndexingError`, advanced indexing is
         # used
         advanced = False
         for i, arg in enumerate(args):
@@ -537,7 +537,7 @@ class _tensor_py_operators:
             if arg is not np.newaxis:
                 try:
-                    aet.subtensor.Subtensor.convert(arg)
+                    aet.subtensor.index_vars_to_types(arg)
                 except AdvancedIndexingError:
                     if advanced:
                         break
@@ -589,7 +589,7 @@ class _tensor_py_operators:
         else:
             return aet.subtensor.Subtensor(args)(
                 self,
-                *aet.subtensor.Subtensor.collapse(
+                *aet.subtensor.get_slice_elements(
                     args, lambda entry: isinstance(entry, Variable)
                 ),
             )
......
@@ -23,6 +23,7 @@ from aesara.tensor.math import sum as aet_sum
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
+    AdvancedIndexingError,
     AdvancedSubtensor,
     AdvancedSubtensor1,
     IncSubtensor,
@@ -35,6 +36,7 @@ from aesara.tensor.subtensor import (
     basic_shape,
     get_canonical_form_slice,
     inc_subtensor,
+    index_vars_to_types,
     indexed_result_shape,
     set_subtensor,
     take,
@@ -2558,3 +2560,16 @@ def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res):
     z = tensor3("z")
     y = inc_subtensor(x[indices], z, set_instead_of_inc=set_instead_of_inc)
     assert pprint(y) == exp_res
def test_index_vars_to_types():
    """Check the error and success paths of ``index_vars_to_types``.

    Covers three cases visible in this test: a boolean tensor raises
    ``AdvancedIndexingError`` (booleans require advanced indexing), a plain
    Python ``int`` raises ``TypeError`` (not a symbolic index), and an
    ``iscalar`` variable is accepted and mapped to a ``scal.Scalar`` type.
    """
    # Boolean indexing cannot be expressed as a basic Subtensor index.
    x = aet.as_tensor_variable(np.array([True, False]))
    with pytest.raises(AdvancedIndexingError):
        index_vars_to_types(x)

    # A raw Python integer is not a valid symbolic index entry.
    with pytest.raises(TypeError):
        index_vars_to_types(1)

    # A symbolic integer scalar converts to its scalar type.
    res = index_vars_to_types(iscalar)
    assert isinstance(res, scal.Scalar)
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment