提交 df2a45d8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary exceptions and update comments in local_subtensor_make_vector

Closes #97
上级 eda01ce9
...@@ -704,53 +704,51 @@ def local_subtensor_inc_subtensor(fgraph, node): ...@@ -704,53 +704,51 @@ def local_subtensor_inc_subtensor(fgraph, node):
@register_useless @register_useless
@local_optimizer([Subtensor, AdvancedSubtensor1]) @local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(fgraph, node): def local_subtensor_make_vector(fgraph, node):
""" """Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.
Replace all subtensor(make_vector) like:
[a,b,c][0] -> a Replace all ``Subtensor`` and ``MakeVector`` cases like:
[a,b,c][0:2] -> [a,b] [a,b,c][0] -> a
[a,b,c][0:2] -> [a,b]
Replace all AdvancedSubtensor1(make_vector) like: Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like:
[a,b,c][[0,2]] -> [a,c] [a,b,c][[0,2]] -> [a,c]
We can do this for constant indexes. We can do this for constant indexes.
.. note:
This optimization implicitly relies on shape optimizations.
TODO: This only applies to a single indexed dimension; we should have
something more general for constant ``*Subtensor*`` graphs (or perhaps
include this kind of work in the constant folding).
""" """
if not isinstance(node.op, (Subtensor, AdvancedSubtensor1)):
return False
x = node.inputs[0] x = node.inputs[0]
if not x.owner or not isinstance(x.owner.op, MakeVector): if not x.owner or not isinstance(x.owner.op, MakeVector):
return False return False
make_vector_op = x.owner.op make_vector_op = x.owner.op
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature (idx,) = node.op.idx_list
try:
(idx,) = node.op.idx_list
except Exception:
# 'how can you have multiple indexes into a shape?'
raise
if isinstance(idx, (aes.Scalar, TensorType)): if isinstance(idx, (aes.Scalar, TensorType)):
# The idx is a Scalar, ie a Type. This means the actual index
# is contained in node.inputs[1]
old_idx, idx = idx, node.inputs[1] old_idx, idx = idx, node.inputs[1]
assert idx.type == old_idx assert idx.type == old_idx
elif isinstance(node.op, AdvancedSubtensor1): elif isinstance(node.op, AdvancedSubtensor1):
idx = node.inputs[1] idx = node.inputs[1]
else:
return
if isinstance(idx, (int, np.integer)): if isinstance(idx, (int, np.integer)):
# We don't need to copy over any stack traces here
return [x.owner.inputs[idx]] return [x.owner.inputs[idx]]
elif isinstance(idx, Variable): elif isinstance(idx, Variable):
if idx.ndim == 0: if idx.ndim == 0:
# if it is a constant we can do something with it
try: try:
v = get_scalar_constant_value(idx, only_process_constants=True) v = get_scalar_constant_value(idx, only_process_constants=True)
if isinstance(v, np.integer):
# Python 2.4 wants to index only with Python integers
v = int(v)
# We don't need to copy over any stack traces here
try: try:
ret = [x.owner.inputs[v]] ret = [x.owner.inputs[v]]
except IndexError: except IndexError:
...@@ -761,28 +759,20 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -761,28 +759,20 @@ def local_subtensor_make_vector(fgraph, node):
elif idx.ndim == 1 and isinstance(idx, Constant): elif idx.ndim == 1 and isinstance(idx, Constant):
values = list(map(int, list(idx.value))) values = list(map(int, list(idx.value)))
ret = make_vector_op(*[x.owner.inputs[v] for v in values]) ret = make_vector_op(*[x.owner.inputs[v] for v in values])
# Copy over stack trace from previous output to new output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
ret = patternbroadcast(ret, node.outputs[0].broadcastable) ret = patternbroadcast(ret, node.outputs[0].broadcastable)
return [ret] return [ret]
else:
raise TypeError("case not expected")
elif isinstance(idx, slice): elif isinstance(idx, slice):
# it is a slice of ints and/or Variables # The index is a slice. If it's a constant slice, we can perform the
# check subtensor to see if it can contain constant variables, and if # index operation here.
# it can, then try to unpack them.
try: try:
const_slice = node.op.get_constant_idx(node.inputs, allow_partial=False)[0] const_slice = node.op.get_constant_idx(node.inputs, allow_partial=False)[0]
ret = make_vector_op(*x.owner.inputs[const_slice]) ret = make_vector_op(*x.owner.inputs[const_slice])
# Copy over stack trace from previous outputs to new output
copy_stack_trace(node.outputs, ret) copy_stack_trace(node.outputs, ret)
ret = patternbroadcast(ret, node.outputs[0].broadcastable) ret = patternbroadcast(ret, node.outputs[0].broadcastable)
return [ret] return [ret]
except NotScalarConstantError: except NotScalarConstantError:
pass pass
else:
raise TypeError("case not expected")
@register_useless @register_useless
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论