提交 ccbab653 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move subtensor lift rewrites to their own module

上级 5fa5c9ba
......@@ -15,4 +15,5 @@ import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
import pytensor.tensor.rewriting.subtensor
import pytensor.tensor.rewriting.subtensor_lift
import pytensor.tensor.rewriting.uncanonicalize
import itertools
import sys
from collections.abc import Iterable
import numpy as np
......@@ -21,11 +20,9 @@ from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import (
Alloc,
Join,
MakeVector,
ScalarFromTensor,
TensorFromScalar,
alloc,
as_tensor,
cast,
concatenate,
get_scalar_constant_value,
......@@ -38,11 +35,8 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import (
Dot,
add,
and_,
ceil_intdiv,
dot,
eq,
ge,
gt,
......@@ -60,11 +54,8 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize,
)
from pytensor.tensor.shape import (
Shape,
SpecifyShape,
shape_padleft,
shape_tuple,
specify_shape,
)
from pytensor.tensor.sharedvar import TensorSharedVariable
from pytensor.tensor.subtensor import (
......@@ -78,7 +69,6 @@ from pytensor.tensor.subtensor import (
advanced_subtensor,
advanced_subtensor1,
as_index_constant,
as_index_literal,
get_canonical_form_slice,
get_constant_idx,
get_idx_list,
......@@ -277,64 +267,6 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
return [new_res]
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_dot(fgraph, node):
    """Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.

    ``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
    the remaining entries of ``idxs`` (if any), modified to skip the
    second-to-last dimension of ``B`` (because dot sums over this dimension).

    Returns ``[new_output]`` on success, or ``None`` (implicitly) to signal
    that no rewrite applies.
    """
    if not isinstance(node.op, Subtensor):
        return
    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
        return
    # If there is other node that use the outputs of the dot
    # We don't want to compute twice the sub part.
    if len(fgraph.clients[node.inputs[0]]) > 1:
        return

    a = node.inputs[0].owner.inputs[0]
    b = node.inputs[0].owner.inputs[1]

    # Concrete index entries (ints/slices/scalar vars) for this Subtensor.
    idx_list = get_idx_list(node.inputs, node.op.idx_list)

    num_a_indices = min(a.ndim - 1, len(idx_list))
    a_indices = idx_list[:num_a_indices]
    b_indices = idx_list[num_a_indices:]

    # This is necessary because np.dot sums the last index of a with the second to last of b
    # so we want to skip the second-to-last index into b.
    # This wasn't necessary for a, because we just omitted the last index.
    # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
    # (dot also handles b.ndim < 2 as a special case)
    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
        b_indices = (
            b_indices[: b.ndim - 2]
            + (slice(None, None, None),)
            + b_indices[b.ndim - 2 :]
        )

    # Note: when a_indices is empty this builds a[()], which is a no-op index.
    a_sub = a.__getitem__(tuple(a_indices))
    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b

    # Copy over previous output stacktrace to a_sub and b_sub,
    # because an error in the subtensor operation (e.g. an index error)
    # on either a or b must correspond to an error in the
    # subtensor operation on their dot product.
    copy_stack_trace(node.outputs[0], [a_sub, b_sub])

    # Copy over previous output stacktrace and previous dot product stacktrace,
    # because an error here may correspond to an either in either the original
    # dot product, or in the dot product after the subtensor operation.
    r = dot(a_sub, b_sub)
    copy_stack_trace([node.outputs[0], node.inputs[0]], r)

    return [r]
@register_infer_shape
@register_useless
@register_canonicalize
......@@ -420,75 +352,6 @@ def local_useless_slice(fgraph, node):
return [out]
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize("fast_compile")
@node_rewriter([Subtensor])
def local_subtensor_lift(fgraph, node):
    """
    unary(x)[idx] -> unary(x[idx])#any broadcast pattern.

    Handles the following unary ops:
    elemwise(x,...)[idx] -> elemwise(x[idx],...)
      when x,... are broadcasted scalar or not broadcasted at all

    Returns ``[new_output]`` on success and ``False``/``None`` otherwise.
    """
    if isinstance(node.op, Subtensor):
        u = node.inputs[0]
        # Only lift when the elemwise output feeds this Subtensor alone;
        # otherwise the elemwise would end up computed twice.
        if u.owner is None or len(fgraph.clients[u]) > 1:
            return False

        if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
            # Unary case: index the single input first, then re-apply the op.
            idx = node.inputs[1:]
            x_idx = node.op(u.owner.inputs[0], *idx)
            # Copy over previous output stacktrace
            copy_stack_trace(node.outputs, x_idx)
            ret = u.owner.op(x_idx)
            # Copy over previous output stacktrace
            # and stacktrace from previous unary operation
            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
            return [ret]

        if isinstance(u.owner.op, Elemwise):
            new_inputs = []
            if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
                # There is no broadcastable in the inputs
                idx = node.inputs[1:]
                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
                # Copy over previous output stacktrace
                copy_stack_trace(node.outputs[0], new_inputs)

                ret = u.owner.op(*new_inputs)
                # Copy over previous output stacktrace
                # and stacktrace from previous unary operation
                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
                return [ret]
            elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
                # There is no broadcastable in the inputs or it is scalar
                # (i.e. every dim broadcastable, acting like a scalar).
                idx = node.inputs[1:]
                new_inputs = []
                for i in u.owner.inputs:
                    if sum(i.type.broadcastable) == 0:
                        new_inputs.append(node.op(i, *idx))
                    else:
                        # If the subtensor remove some dims, we must
                        # lower the number of dimensions of this scalar.
                        if node.outputs[0].ndim == i.ndim:
                            new_inputs.append(i)
                        else:
                            new_inputs.append(
                                i.dimshuffle(["x"] * node.outputs[0].ndim)
                            )

                # Copy over previous output stacktrace
                copy_stack_trace(node.outputs[0], new_inputs)

                ret = u.owner.op(*new_inputs)
                # Copy over previous output stacktrace
                # and stacktrace from previous unary operation
                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
                return [ret]
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
......@@ -619,76 +482,6 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
return [node.inputs[0].dimshuffle(tuple(remain_dim))]
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_alloc(fgraph, node):
    """
    alloc(val)[x:y] -> alloc(val[...])
    alloc(val)[x:y] -> alloc(val)

    This can be seen as a lift, but it also reduce the number of computation/memory.

    Returns ``[new_alloc]`` on success and ``False`` otherwise.
    """
    if not isinstance(node.op, Subtensor):
        return False
    u = node.inputs[0]
    if u.owner is None:
        return False
    if not isinstance(u.owner.op, Alloc):
        return False
    slices = get_idx_list(node.inputs, node.op.idx_list)
    val = u.owner.inputs[0]
    dims = u.owner.inputs[1:]
    assert len(slices) <= len(dims)

    # Number of dimensions added to val
    n_added_dims = u.ndim - val.ndim
    # Dimensions of the returned alloc
    nw_dims = []
    # Slices to take from val
    val_slices = []

    for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
        # If val was not copied over that dim,
        # we need to take the appropriate subtensor on it.
        if i >= n_added_dims:
            # We check that the corresponding val dimensions was
            # not a broadcasted dimensions.
            if (
                val.type.ndim > (i - n_added_dims)
                and val.type.broadcastable[i - n_added_dims]
            ):
                # Broadcasted (length-1) dim: slicing it would be wrong, so
                # keep the full dim and let the alloc resize it.
                val_slices.append(slice(None))
            else:
                val_slices.append(sl)

        csl, _ = get_canonical_form_slice(sl, dim)
        if type(csl) is not slice:
            # That dimension is removed.
            pass
        else:
            nw_dim = csl.stop - csl.start

            if csl.step != 1:
                # Do not add the ceil_intdiv() graphs in the graphs
                # when this is not needed as it prevent detecting the
                # correct broadcast pattern.
                nw_dim = ceil_intdiv(nw_dim, csl.step)
            nw_dims += [nw_dim]

    nw_val = val[tuple(val_slices)]
    # Unsliced trailing dims are kept as-is.
    nw_dims += dims[len(slices) :]
    if nw_val.ndim > len(nw_dims):
        return False
    rval = alloc(nw_val, *nw_dims)
    if not isinstance(rval, list | tuple):
        rval = [rval]
    return rval
@register_specialize
@register_canonicalize
@node_rewriter([Subtensor])
......@@ -728,91 +521,6 @@ def local_subtensor_inc_subtensor(fgraph, node):
return
@register_infer_shape
@register_specialize
@register_canonicalize("fast_compile")
@register_useless
@node_rewriter([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(fgraph, node):
    """Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.

    Replace all ``Subtensor`` and ``MakeVector`` cases like:
        [a,b,c][0] -> a
        [a,b,c][0:2] -> [a,b]

    Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like:
        [a,b,c][[0,2]] -> [a,c]

    We can do this for constant indexes.

    .. note:

        This optimization implicitly relies on shape optimizations.

    TODO: This only applies to a single indexed dimension; we should have
    something more general for constant ``*Subtensor*`` graphs (or perhaps
    include this kind of work in the constant folding).
    """
    if not isinstance(node.op, Subtensor | AdvancedSubtensor1):
        return False

    x = node.inputs[0]

    if not (x.owner and isinstance(x.owner.op, MakeVector)):
        return False

    make_vector_op = x.owner.op

    if isinstance(node.op, Subtensor):
        idxs = node.op.idx_list

        # Subtensor has no indexes, return make_vector
        if not idxs:
            return [x]

        # MakeVector outputs are 1d, so there is exactly one index entry.
        (idx,) = idxs

        if isinstance(idx, ScalarType | TensorType):
            # The index is symbolic; fetch the actual index variable.
            old_idx, idx = idx, node.inputs[1]
            assert idx.type.is_super(old_idx)
    elif isinstance(node.op, AdvancedSubtensor1):
        idx = node.inputs[1]

    if isinstance(idx, int | np.integer):
        # Constant integer index: select the corresponding MakeVector input.
        return [x.owner.inputs[idx]]
    elif isinstance(idx, Variable):
        if idx.ndim == 0:
            # Symbolic scalar index; usable only if it is in fact constant.
            try:
                v = get_underlying_scalar_constant_value(
                    idx, only_process_constants=True
                )
                try:
                    ret = [x.owner.inputs[v]]
                except IndexError:
                    raise NotScalarConstantError("Bad user graph!")
                return ret
            except NotScalarConstantError:
                pass
        elif idx.ndim == 1 and isinstance(idx, Constant):
            # Constant vector of indices: gather the selected inputs.
            values = list(map(int, list(idx.value)))
            ret = make_vector_op(*[x.owner.inputs[v] for v in values])
            copy_stack_trace(node.outputs[0], ret)
            return [ret]
    elif isinstance(idx, slice):
        # The index is a slice.  If it's a constant slice, we can perform the
        # index operation here.
        try:
            const_slice = get_constant_idx(
                node.op.idx_list, node.inputs, allow_partial=False
            )[0]
            ret = make_vector_op(*x.owner.inputs[const_slice])
            copy_stack_trace(node.outputs, ret)
            return [ret]
        except NotScalarConstantError:
            pass
@register_infer_shape
@register_useless
@register_canonicalize
......@@ -1615,95 +1323,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
return [r]
@register_specialize
@register_canonicalize
@node_rewriter([Subtensor])
def local_subtensor_shape_constant(fgraph, node):
    r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.

    We want to convert graphs like

        Subtensor{int64} [id A] ''
         |Shape [id B] ''
         | |<TensorType(float64, row)> [id C]
         |ScalarConstant{0} [id D]

    into

        TensorConstant{1}

    TODO: Something like `local_shape_to_shape_i` should be a general
    canonicalization, and not a `ShapeFeature`-dependent rewrite.  If that were
    the case, we could change this to only operate on `Shape_i`\s.
    Currently, we're not handling them because they should only appear when
    `ShapeFeature` is present, and it will also simplify/remove them.
    """
    if not isinstance(node.op, Subtensor):
        return False

    shape = node.inputs[0]

    if not (shape.owner and isinstance(shape.owner.op, Shape)):
        return False

    shape_arg = shape.owner.inputs[0]

    (idx,) = get_idx_list(node.inputs, node.op.idx_list)

    try:
        idx_val = as_index_literal(idx)
    except NotScalarConstantError:
        return False

    assert idx_val != np.newaxis

    if not isinstance(shape_arg.type, TensorType):
        return False

    # `broadcastable[i]` is True only for dims statically known to be 1;
    # a scalar index yields one bool, a slice index yields a tuple of bools.
    shape_parts = shape_arg.type.broadcastable[idx_val]

    if isinstance(shape_parts, Iterable):
        if all(shape_parts):
            return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
    elif shape_parts:
        return [as_tensor(1, dtype=np.int64)]
@register_canonicalize
@node_rewriter([Subtensor])
def local_subtensor_SpecifyShape_lift(fgraph, node):
    """Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
    if not isinstance(node.op, Subtensor):
        return False

    specify_shape_node = node.inputs[0]

    if not (
        specify_shape_node.owner
        and isinstance(specify_shape_node.owner.op, SpecifyShape)
    ):
        return False

    obj_arg = specify_shape_node.owner.inputs[0]
    shape_arg = specify_shape_node.owner.inputs[1:]

    indices = get_idx_list(node.inputs, node.op.idx_list)

    # Bail on slice indices (literal or symbolic SliceType): they keep
    # dimensions alive, so dropping the first len(indices) shape entries
    # below would be incorrect.
    if any(
        isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType)
        for index in indices
    ):
        return False

    new_obj_arg = obj_arg[indices]
    # No need to specify shape for scalar outputs
    if new_obj_arg.ndim == 0:
        return [new_obj_arg]
    return [specify_shape(new_obj_arg, shape_arg[len(indices) :])]
@register_specialize
@node_rewriter([Join])
def local_join_subtensors(fgraph, node):
......
from collections.abc import Iterable
import numpy as np
from pytensor import Variable
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
Alloc,
MakeVector,
alloc,
as_tensor,
get_underlying_scalar_constant_value,
register_infer_shape,
)
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, ceil_intdiv, dot
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.subtensor import register_useless
from pytensor.tensor.shape import (
Shape,
SpecifyShape,
specify_shape,
)
from pytensor.tensor.subtensor import (
AdvancedSubtensor1,
Subtensor,
as_index_literal,
get_canonical_form_slice,
get_constant_idx,
get_idx_list,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import SliceType
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_dot(fgraph, node):
    """Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.

    ``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
    the remaining entries of ``idxs`` (if any), modified to skip the
    second-to-last dimension of ``B`` (because dot sums over this dimension).
    """
    if not isinstance(node.op, Subtensor):
        return

    dot_out = node.inputs[0]
    if not (dot_out.owner and isinstance(dot_out.owner.op, Dot)):
        return

    # If the dot output feeds other nodes as well, lifting the subtensor
    # would force the dot to be computed twice, so leave the graph alone.
    if len(fgraph.clients[dot_out]) > 1:
        return

    lhs, rhs = dot_out.owner.inputs

    indices = get_idx_list(node.inputs, node.op.idx_list)

    n_lhs = min(lhs.ndim - 1, len(indices))
    lhs_idxs = indices[:n_lhs]
    rhs_idxs = indices[n_lhs:]

    # np.dot contracts the last axis of lhs with the second-to-last axis of
    # rhs, so the index that would land on that contracted rhs axis has to be
    # skipped by inserting a full slice there.  Nothing is needed for lhs,
    # whose contracted (last) axis was simply never indexed.  Skip the
    # insertion when rhs is 1d: we then want rhs_sub = rhs, not rhs[:]
    # (dot also special-cases rhs.ndim < 2).
    if rhs.ndim > 1 and len(rhs_idxs) >= rhs.ndim - 1:
        rhs_idxs = (
            rhs_idxs[: rhs.ndim - 2]
            + (slice(None, None, None),)
            + rhs_idxs[rhs.ndim - 2 :]
        )

    lhs_sub = lhs[tuple(lhs_idxs)]
    rhs_sub = rhs[tuple(rhs_idxs)] if rhs_idxs else rhs

    # An indexing error on either operand corresponds to an indexing error on
    # the original dot output, so both share the output's stack trace.
    copy_stack_trace(node.outputs[0], [lhs_sub, rhs_sub])

    # The new dot may fail either like the original dot or like the original
    # subtensor, so it inherits both stack traces.
    new_out = dot(lhs_sub, rhs_sub)
    copy_stack_trace([node.outputs[0], node.inputs[0]], new_out)

    return [new_out]
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize("fast_compile")
@node_rewriter([Subtensor])
def local_subtensor_lift(fgraph, node):
    """
    unary(x)[idx] -> unary(x[idx])#any broadcast pattern.

    Handles the following unary ops:
    elemwise(x,...)[idx] -> elemwise(x[idx],...)
      when x,... are broadcasted scalar or not broadcasted at all

    Returns ``[new_output]`` on success and ``False``/``None`` otherwise.
    """
    if isinstance(node.op, Subtensor):
        u = node.inputs[0]
        # Only lift when the elemwise output feeds this Subtensor alone;
        # otherwise the elemwise would end up computed twice.
        if u.owner is None or len(fgraph.clients[u]) > 1:
            return False

        if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
            # Unary case: index the single input first, then re-apply the op.
            idx = node.inputs[1:]
            x_idx = node.op(u.owner.inputs[0], *idx)
            # Copy over previous output stacktrace
            copy_stack_trace(node.outputs, x_idx)
            ret = u.owner.op(x_idx)
            # Copy over previous output stacktrace
            # and stacktrace from previous unary operation
            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
            return [ret]

        if isinstance(u.owner.op, Elemwise):
            new_inputs = []
            if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
                # There is no broadcastable in the inputs
                idx = node.inputs[1:]
                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
                # Copy over previous output stacktrace
                copy_stack_trace(node.outputs[0], new_inputs)

                ret = u.owner.op(*new_inputs)
                # Copy over previous output stacktrace
                # and stacktrace from previous unary operation
                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
                return [ret]
            elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
                # There is no broadcastable in the inputs or it is scalar
                # (i.e. every dim broadcastable, acting like a scalar).
                idx = node.inputs[1:]
                new_inputs = []
                for i in u.owner.inputs:
                    if sum(i.type.broadcastable) == 0:
                        new_inputs.append(node.op(i, *idx))
                    else:
                        # If the subtensor remove some dims, we must
                        # lower the number of dimensions of this scalar.
                        if node.outputs[0].ndim == i.ndim:
                            new_inputs.append(i)
                        else:
                            new_inputs.append(
                                i.dimshuffle(["x"] * node.outputs[0].ndim)
                            )

                # Copy over previous output stacktrace
                copy_stack_trace(node.outputs[0], new_inputs)

                ret = u.owner.op(*new_inputs)
                # Copy over previous output stacktrace
                # and stacktrace from previous unary operation
                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
                return [ret]
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_alloc(fgraph, node):
    """
    alloc(val)[x:y] -> alloc(val[...])
    alloc(val)[x:y] -> alloc(val)

    This can be seen as a lift, but it also reduce the number of computation/memory.

    Returns ``[new_alloc]`` on success and ``False`` otherwise.
    """
    if not isinstance(node.op, Subtensor):
        return False
    u = node.inputs[0]
    if u.owner is None:
        return False
    if not isinstance(u.owner.op, Alloc):
        return False
    slices = get_idx_list(node.inputs, node.op.idx_list)
    val = u.owner.inputs[0]
    dims = u.owner.inputs[1:]
    assert len(slices) <= len(dims)

    # Number of dimensions added to val
    n_added_dims = u.ndim - val.ndim
    # Dimensions of the returned alloc
    nw_dims = []
    # Slices to take from val
    val_slices = []

    for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
        # If val was not copied over that dim,
        # we need to take the appropriate subtensor on it.
        if i >= n_added_dims:
            # We check that the corresponding val dimensions was
            # not a broadcasted dimensions.
            if (
                val.type.ndim > (i - n_added_dims)
                and val.type.broadcastable[i - n_added_dims]
            ):
                # Broadcasted (length-1) dim: slicing it would be wrong, so
                # keep the full dim and let the alloc resize it.
                val_slices.append(slice(None))
            else:
                val_slices.append(sl)

        csl, _ = get_canonical_form_slice(sl, dim)
        if type(csl) is not slice:
            # That dimension is removed.
            pass
        else:
            nw_dim = csl.stop - csl.start

            if csl.step != 1:
                # Do not add the ceil_intdiv() graphs in the graphs
                # when this is not needed as it prevent detecting the
                # correct broadcast pattern.
                nw_dim = ceil_intdiv(nw_dim, csl.step)
            nw_dims += [nw_dim]

    nw_val = val[tuple(val_slices)]
    # Unsliced trailing dims are kept as-is.
    nw_dims += dims[len(slices) :]
    if nw_val.ndim > len(nw_dims):
        return False
    rval = alloc(nw_val, *nw_dims)
    if not isinstance(rval, list | tuple):
        rval = [rval]
    return rval
@register_canonicalize
@node_rewriter([Subtensor])
def local_subtensor_SpecifyShape_lift(fgraph, node):
    """Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
    if not isinstance(node.op, Subtensor):
        return False

    ss_out = node.inputs[0]
    if ss_out.owner is None or not isinstance(ss_out.owner.op, SpecifyShape):
        return False

    inner_var, *shape_vars = ss_out.owner.inputs

    indices = get_idx_list(node.inputs, node.op.idx_list)

    # Only handle plain integer indexing: any slice (literal or symbolic
    # SliceType) keeps a dimension alive, and the remaining-shape bookkeeping
    # below assumes exactly one dropped leading dimension per index.
    for index in indices:
        if isinstance(index, slice) or isinstance(
            getattr(index, "type", None), SliceType
        ):
            return False

    indexed = inner_var[indices]

    # A 0d result carries no shape left to re-assert.
    if indexed.ndim == 0:
        return [indexed]

    return [specify_shape(indexed, shape_vars[len(indices) :])]
@register_infer_shape
@register_specialize
@register_canonicalize("fast_compile")
@register_useless
@node_rewriter([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(fgraph, node):
    """Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.

    Replace all ``Subtensor`` and ``MakeVector`` cases like:
        [a,b,c][0] -> a
        [a,b,c][0:2] -> [a,b]

    Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like:
        [a,b,c][[0,2]] -> [a,c]

    We can do this for constant indexes.

    .. note:

        This optimization implicitly relies on shape optimizations.

    TODO: This only applies to a single indexed dimension; we should have
    something more general for constant ``*Subtensor*`` graphs (or perhaps
    include this kind of work in the constant folding).
    """
    if not isinstance(node.op, Subtensor | AdvancedSubtensor1):
        return False

    x = node.inputs[0]

    if not (x.owner and isinstance(x.owner.op, MakeVector)):
        return False

    make_vector_op = x.owner.op

    if isinstance(node.op, Subtensor):
        idxs = node.op.idx_list

        # Subtensor has no indexes, return make_vector
        if not idxs:
            return [x]

        # MakeVector outputs are 1d, so there is exactly one index entry.
        (idx,) = idxs

        if isinstance(idx, ps.ScalarType | TensorType):
            # The index is symbolic; fetch the actual index variable.
            old_idx, idx = idx, node.inputs[1]
            assert idx.type.is_super(old_idx)
    elif isinstance(node.op, AdvancedSubtensor1):
        idx = node.inputs[1]

    if isinstance(idx, int | np.integer):
        # Constant integer index: select the corresponding MakeVector input.
        return [x.owner.inputs[idx]]
    elif isinstance(idx, Variable):
        if idx.ndim == 0:
            # Symbolic scalar index; usable only if it is in fact constant.
            try:
                v = get_underlying_scalar_constant_value(
                    idx, only_process_constants=True
                )
                try:
                    ret = [x.owner.inputs[v]]
                except IndexError:
                    raise NotScalarConstantError("Bad user graph!")
                return ret
            except NotScalarConstantError:
                pass
        elif idx.ndim == 1 and isinstance(idx, Constant):
            # Constant vector of indices: gather the selected inputs.
            values = list(map(int, list(idx.value)))
            ret = make_vector_op(*[x.owner.inputs[v] for v in values])
            copy_stack_trace(node.outputs[0], ret)
            return [ret]
    elif isinstance(idx, slice):
        # The index is a slice.  If it's a constant slice, we can perform the
        # index operation here.
        try:
            const_slice = get_constant_idx(
                node.op.idx_list, node.inputs, allow_partial=False
            )[0]
            ret = make_vector_op(*x.owner.inputs[const_slice])
            copy_stack_trace(node.outputs, ret)
            return [ret]
        except NotScalarConstantError:
            pass
@register_specialize
@register_canonicalize
@node_rewriter([Subtensor])
def local_subtensor_shape_constant(fgraph, node):
    r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.

    Graphs like

        Subtensor{int64} [id A] ''
         |Shape [id B] ''
         | |<TensorType(float64, row)> [id C]
         |ScalarConstant{0} [id D]

    become ``TensorConstant{1}`` when the indexed dimension is statically
    known to be broadcastable (i.e. of length 1).

    TODO: Something like `local_shape_to_shape_i` should be a general
    canonicalization, and not a `ShapeFeature`-dependent rewrite.  If that were
    the case, we could change this to only operate on `Shape_i`\s.
    Currently, we're not handling them because they should only appear when
    `ShapeFeature` is present, and it will also simplify/remove them.
    """
    if not isinstance(node.op, Subtensor):
        return False

    shape_out = node.inputs[0]
    if not (shape_out.owner and isinstance(shape_out.owner.op, Shape)):
        return False

    shaped_var = shape_out.owner.inputs[0]

    (index,) = get_idx_list(node.inputs, node.op.idx_list)
    try:
        literal_index = as_index_literal(index)
    except NotScalarConstantError:
        return False

    assert literal_index != np.newaxis

    if not isinstance(shaped_var.type, TensorType):
        return False

    # A scalar index picks out one bool; a slice yields a tuple of bools.
    bcast = shaped_var.type.broadcastable[literal_index]
    if isinstance(bcast, Iterable):
        # Rewrite only when *every* selected dimension is broadcastable.
        if all(bcast):
            return [as_tensor([1] * len(bcast), dtype=np.int64, ndim=1)]
    elif bcast:
        return [as_tensor(1, dtype=np.int64)]
......@@ -9,27 +9,19 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph, vectorize_graph
from pytensor.graph import vectorize_graph
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.type import Type
from pytensor.raise_op import Assert
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
from pytensor.tensor.basic import Alloc, _convert_to_int8
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, add, dot, exp, sqr
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot, dot, exp, sqr
from pytensor.tensor.rewriting.subtensor import (
local_replace_AdvancedSubtensor,
local_subtensor_make_vector,
local_subtensor_shape_constant,
)
from pytensor.tensor.shape import (
SpecifyShape,
_shape,
shape,
specify_shape,
)
from pytensor.tensor.subtensor import (
......@@ -49,10 +41,7 @@ from pytensor.tensor.type import (
dmatrix,
fmatrix,
iscalar,
iscalars,
ivector,
lscalar,
lscalars,
matrix,
scalar,
tensor,
......@@ -60,7 +49,7 @@ from pytensor.tensor.type import (
tensor4,
vector,
)
from pytensor.tensor.type_other import make_slice, slicetype
from pytensor.tensor.type_other import make_slice
from tests import unittest_tools as utt
from tests.unittest_tools import create_pytensor_param
......@@ -664,262 +653,6 @@ class TestSubtensorIncSubtensor:
assert np.array_equal(f(x_, i_, v_), v_.astype("int8"))
class TestLocalSubtensorMakeVector:
    """Tests for ``local_subtensor_make_vector``: constant indexing into a
    ``MakeVector`` output should be resolved at rewrite time."""

    # Run with only the rewrite under test added on top of FAST_RUN.
    mode = get_mode("FAST_RUN").including("local_subtensor_make_vector")

    def test_scalar_idx(self):
        # [x, y, z][0] collapses to x directly: only a DeepCopyOp remains.
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        f = function([x, y, z], v[0], mode=self.mode)

        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, DeepCopyOp)
        assert f(0, 1, 2) == 0

    def test_idx_symbolic(self):
        # A constant 1d tensor index must preserve the MakeVector dtype.
        x, y, z = iscalars("xyz")
        v = MakeVector("int32")(x, y, z)
        idx = pt.as_tensor([0], dtype=np.int64)
        f = function([x, y, z], v[idx], mode=self.mode)

        opt_fgraph = f.maker.fgraph
        assert opt_fgraph.outputs[0].dtype == "int32"
        assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector)
        assert f(0, 1, 2) == np.array([0], dtype=np.int32)

    def test_slice_idx_start(self):
        # v[1:] becomes MakeVector(y, z); x becomes unused, hence the
        # on_unused_input="ignore".
        x, y, z = iscalars("xyz")
        v = MakeVector("int32")(x, y, z)
        f = function([x, y, z], v[1:], mode=self.mode, on_unused_input="ignore")

        opt_fgraph = f.maker.fgraph
        assert opt_fgraph.outputs[0].dtype == "int32"
        assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector)
        assert len(opt_fgraph.outputs[0].owner.inputs) == 2

        r = f(0, 1, 2)
        assert r[0] == 1 and r[1] == 2

    def test_slice_idx_stop(self):
        # v[:2] becomes MakeVector(x, y).
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        f = function([x, y, z], v[:2], mode=self.mode)

        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, MakeVector)
        assert len(prog[0].inputs) == 2

        r = f(0, 1, 2)
        assert r[0] == 0 and r[1] == 1

    def test_slice_idx_step(self):
        # v[::2] becomes MakeVector(x, z).
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        f = function([x, y, z], v[::2], mode=self.mode)

        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, MakeVector)
        assert len(prog[0].inputs) == 2

        r = f(0, 1, 2)
        assert r[0] == 0 and r[1] == 2

    def test_AdvancedSubtensor1_idx(self):
        # Constant advanced index v[[0, 2]] becomes MakeVector(x, z).
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        f = function([x, y, z], v[[0, 2]], mode=self.mode)

        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, MakeVector)
        assert len(prog[0].inputs) == 2

        r = f(0, 1, 2)
        assert r[0] == 0 and r[1] == 2

    def test_MakeVector_idx(self):
        # NOTE(review): the `q` from `lscalars` is immediately shadowed by the
        # constant make_vector(0, 2) below; the fourth scalar is unused.
        x, y, z, q = lscalars("xyzq")
        v = make_vector(x, y, z)
        q = make_vector(0, 2)
        f = function([x, y, z], v[q], mode=self.mode)

        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, MakeVector)
        assert len(prog[0].inputs) == 2

        r = f(0, 1, 2)
        assert r[0] == 0 and r[1] == 2

    def test_stack_trace(self):
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)

        mode = get_default_mode().including("local_subtensor_make_vector")

        # list of subtensor cases, where local_subtensor_make_vector
        # inserts a new MakeVector node
        v_subtensors = [v[:2], v[::2], v[[0, 2]]]

        for v_subtensor in v_subtensors:
            f = function([x, y, z], v_subtensor, mode=mode)
            assert check_stack_trace(f, ops_to_check="all")

    def test_empty_subtensor(self):
        # v[()] (no indices at all) must be replaced by v itself.
        x, y = lscalars("xy")
        v = make_vector(x, y)
        out = v[()]

        fgraph = FunctionGraph(outputs=[out], clone=False)
        node = fgraph.outputs[0].owner
        assert isinstance(node.op, Subtensor)

        assert local_subtensor_make_vector.transform(fgraph, node) == [v]
class TestLocalSubtensorLift:
def test_basic(self):
# basic test that the Op works
x = matrix("x")
f = function([x], exp(x)[0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check="all")
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor) # first subtensor
assert prog[1].op == exp
assert len(prog) == 2
f([[0, 1], [2, 3]]) # let debugmode test something
def test_basic_1(self):
# as test0, but we reuse the output of the elemwise
# So we should not lift the subtensor
x = matrix("x")
f = function([x], [exp(x)[0], exp(x)], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check=[Subtensor, Elemwise])
prog = f.maker.fgraph.toposort()
assert prog[0].op == exp
assert isinstance(prog[1].op, Subtensor) # first subtensor
assert isinstance(prog[2].op, DeepCopyOp)
assert len(prog) == 3
f([[0, 1], [2, 3]]) # let debugmode test something
def test_basic_2(self):
# basic test that the optimization work with scalar broadcasted
x = matrix("x")
y = scalar("y")
z = matrix("z")
f = function([x, y, z], exp(x + y + z)[0], mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, DimShuffle)
assert isinstance(prog[2].op, Subtensor)
assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add}
assert len(prog) == 4
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check=[Subtensor])
# let debugmode test something
f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
def test_basic_3(self):
# as 1, but take a slice
x = matrix("x")
y = scalar("y")
z = matrix("z")
f = function([x, y, z], exp(x + y + z)[0:2], mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, DimShuffle)
assert isinstance(prog[2].op, Subtensor)
assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add}
assert len(prog) == 4
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check=[Subtensor])
# let debugmode test something
f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
def test_basic_4(self):
# basic test that the optimization does work with broadcasting
# for unary elemwise.
y = vector("y")
f = function([y], exp(y.dimshuffle(0, "x"))[0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check="all")
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, DimShuffle)
assert isinstance(prog[1].op, Subtensor)
assert prog[2].op == exp
assert len(prog) == 3
f([4, 5]) # let debugmode test something
@utt.assertFailure_fast
def test_basic_5(self):
    """The rewrite must NOT fire when broadcasting mixes in another input.

    It *could* be extended to handle this case, but right now it doesn't,
    so it shouldn't try.
    """
    x = matrix("x")
    y = vector("y")
    f = function([x, y], exp(x + y)[0], mode=mode_opt)
    # The rewrite does not apply here, so there is no stack trace to check.
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 4
    assert isinstance(topo[0].op, DimShuffle)
    assert topo[1].op == add
    assert isinstance(topo[2].op, Subtensor)
    assert topo[3].op == inplace.exp_inplace
    # Run once so DebugMode can exercise the compiled graph
    f([[0, 1], [2, 3]], [4, 5])
def test_basic_6(self):
    """No lifting when the elemwise output is reused for other computation."""
    x = matrix("x")
    y = vector("y")
    f = function([x, y], [exp(x + y)[0], exp(x + y) + x], mode=mode_opt)
    # The rewrite does not apply here, so there is no stack trace to check.
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 3
    assert isinstance(topo[0].op, DimShuffle)
    # The shared add/exp chain is fused into a single Composite{add,exp}
    assert isinstance(topo[1].op.scalar_op, ps.Composite)
    # The subtensor stays after the elemwise instead of being lifted
    assert isinstance(topo[2].op, Subtensor)
    # Run once so DebugMode can exercise the compiled graph
    f([[0, 1], [2, 3]], [4, 5])
def test_basic_7(self):
    """Scalar output with a scalar input (no broadcasting needed).

    The rewrite used to fail here and display an ERROR message.
    """
    x = vector("x")
    y = scalar("y")
    f = function([x, y], exp(x + y)[0], mode=mode_opt)
    # Stack trace must survive the rewrite
    assert check_stack_trace(f, ops_to_check=Subtensor)
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 2
    assert isinstance(topo[0].op, Subtensor)
    # Composite{add,exp}
    assert isinstance(topo[1].op.scalar_op, ps.Composite)
    # Run once so DebugMode can exercise the compiled graph
    f([1, 2, 3], 4)
class TestLocalSubtensorMerge:
def setup_method(self):
self.x_shapes = [(2, 2), (5, 3), (4, 1), (1, 2), (0, 2), (2, 0), (1, 0), (0, 0)]
......@@ -1803,200 +1536,6 @@ def test_local_set_to_inc_subtensor():
assert check_stack_trace(f2, ops_to_check="all")
def test_local_subtensor_of_alloc():
    """A `Subtensor` of an `Alloc` should not remain the final graph node.

    DebugMode should detect if something goes wrong.
    """
    # Test shape combinations of odd and even sizes, including
    # broadcastable (size-1) dimensions.
    for s in [(3, 5), (4, 6), (3, 8), (4, 7), (1, 5), (5, 1)]:
        x = tensor(
            dtype=config.floatX,
            shape=(1 if s[0] == 1 else None, 1 if s[1] == 1 else None),
        )
        xval = np.zeros(s, dtype=config.floatX)
        yval = np.arange(s[1], dtype=config.floatX)
        for y in [shared(yval), pt.constant([1.0])]:
            # The rows of yx are copies of y
            yx = pt.alloc(y, x.shape[0], x.shape[1])
            # Slice of each row
            z_mat = yx[:, 3:]
            assert z_mat.ndim == 2
            # Only one column
            z_vec = yx[:, 3]
            assert z_vec.ndim == 1
            # Indices whose results are vectors
            slicess = []
            if s[0] != 1:
                slicess.append((2, slice(None)))
            if s[1] != 1:
                slicess.append((slice(None), 3))
            # Indices whose results are matrices.  The trailing commas make
            # every entry a 1-tuple, consistent with the entries above; the
            # original had two bare slices here, which index identically but
            # were inconsistent with the rest of the list.
            slicess += [
                (slice(None), slice(3, None)),
                (slice(3, None),),
                (slice(3, None), slice(3, None)),
                (slice(1, 3), slice(None, -1)),
                (slice(None, None, 2),),
                (slice(1, None, 2),),
            ]
            for slices in slicess:
                # Idiomatic subscription instead of explicit __getitem__
                z = yx[slices]
                f = function([x], z)
                if config.mode != "FAST_COMPILE":
                    # A Subtensor may feed the Alloc's inputs, but must not
                    # be the last node of the rewritten graph
                    assert not isinstance(f.maker.fgraph.toposort()[-1].op, Subtensor)
                val = f(xval)
                assert xval[slices].shape == val.shape
def test_local_subtensor_shape_constant():
    """Subtensors of a `Shape` over static dims constant-fold."""
    dim0 = tensor(dtype=np.float64, shape=(1, None)).shape[0]
    (res,) = local_subtensor_shape_constant.transform(None, dim0.owner)
    assert isinstance(res, Constant)
    assert res.data == 1

    # Make sure it's part of the canonicalizations
    res = rewrite_graph(dim0)
    assert isinstance(res, Constant)
    assert res.data == 1

    # A symbolic index, an open slice from 0, and a symbolic-start slice over
    # a partially-static shape must all be left untouched
    for idx in (lscalar(), slice(0, None), slice(lscalar(), None)):
        var = _shape(tensor(dtype=np.float64, shape=(1, None)))[idx]
        assert not local_subtensor_shape_constant.transform(None, var.owner)

    tail = _shape(tensor(dtype=np.float64, shape=(1, 1)))[1:]
    (res,) = local_subtensor_shape_constant.transform(None, tail.owner)
    assert isinstance(res, Constant)
    assert np.array_equal(res.data, [1])

    tail = _shape(tensor(dtype=np.float64, shape=(None, 1, 1)))[1:]
    (res,) = local_subtensor_shape_constant.transform(None, tail.owner)
    assert isinstance(res, Constant)
    assert np.array_equal(res.data, [1, 1])

    # A test for a non-`TensorType`
    class MyType(Type):
        def filter(self, *args, **kwargs):
            raise NotImplementedError()

        def __eq__(self, other):
            return isinstance(other, MyType) and other.thingy == self.thingy

    non_tensor_dim = shape(Variable(MyType(), None, None))[0]
    assert not local_subtensor_shape_constant.transform(None, non_tensor_dim.owner)
@pytest.mark.parametrize(
    "x, s, idx, x_val, s_val",
    [
        (
            vector(),
            (iscalar(),),
            (1,),
            np.array([1, 2], dtype=config.floatX),
            np.array([2], dtype=np.int64),
        ),
        (
            matrix(),
            (iscalar(), iscalar()),
            (1,),
            np.array([[1, 2], [3, 4]], dtype=config.floatX),
            np.array([2, 2], dtype=np.int64),
        ),
        (
            matrix(),
            (iscalar(), iscalar()),
            (0,),
            np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
            np.array([2, 3], dtype=np.int64),
        ),
        (
            matrix(),
            (iscalar(), iscalar()),
            (1, 1),
            np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
            np.array([2, 3], dtype=np.int64),
        ),
        (
            tensor3(),
            (iscalar(), iscalar(), iscalar()),
            (-1,),
            np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
            np.array([2, 3, 5], dtype=np.int64),
        ),
        (
            tensor3(),
            (iscalar(), iscalar(), iscalar()),
            (-1, 0),
            np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
            np.array([2, 3, 5], dtype=np.int64),
        ),
    ],
)
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
    """`SpecifyShape` is lifted above constant-integer subtensors.

    For 0d results the `SpecifyShape` is dropped entirely; otherwise it is
    moved above the `Subtensor`.  Either way, the rewritten graph must
    compute the same values as the unrewritten one.
    """
    y = specify_shape(x, s)[idx]
    assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)

    rewrites = RewriteDatabaseQuery(include=[None])
    no_rewrites_mode = Mode(optimizer=rewrites)

    y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
    # Splat the arguments directly instead of building a throwaway list first
    y_val = y_val_fn(x_val, *s_val)

    # This rewrite should appear in the canonicalizations
    y_opt = rewrite_graph(y, clone=False)

    if y.ndim == 0:
        # SpecifyShape should be removed altogether
        assert isinstance(y_opt.owner.op, Subtensor)
        assert y_opt.owner.inputs[0] is x
    else:
        assert isinstance(y_opt.owner.op, SpecifyShape)

    y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
    y_opt_val = y_opt_fn(x_val, *s_val)

    assert np.allclose(y_val, y_opt_val)
@pytest.mark.parametrize(
    "x, s, idx",
    [
        (matrix(), (iscalar(), iscalar()), (slice(1, None),)),
        (matrix(), (iscalar(), iscalar()), (slicetype(),)),
        (matrix(), (iscalar(), iscalar()), (1, 0)),
    ],
)
def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
    """Indices the SpecifyShape lift does not handle leave no SpecifyShape on top."""
    indexed = specify_shape(x, s)[idx]
    # This rewrite should appear in the canonicalizations
    rewritten = rewrite_graph(indexed, clone=False)
    assert not isinstance(rewritten.owner.op, SpecifyShape)
@pytest.mark.parametrize(
"axis, slices_fn, expected_nodes",
[
......
import numpy as np
import pytest
import unittest_tools as utt
from pytensor import (
Mode,
Variable,
config,
function,
shared,
)
from pytensor import scalar as ps
from pytensor import tensor as pt
from pytensor.compile import DeepCopyOp, get_default_mode, get_mode
from pytensor.graph import (
Constant,
FunctionGraph,
RewriteDatabaseQuery,
Type,
rewrite_graph,
)
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.tensor import (
add,
exp,
inplace,
iscalar,
iscalars,
lscalar,
lscalars,
matrix,
scalar,
shape,
slicetype,
specify_shape,
tensor,
tensor3,
vector,
)
from pytensor.tensor.basic import MakeVector, make_vector
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.subtensor_lift import (
local_subtensor_make_vector,
local_subtensor_shape_constant,
)
from pytensor.tensor.shape import SpecifyShape, _shape
from pytensor.tensor.subtensor import Subtensor
# Run the tests in a rewriting mode: FAST_COMPILE would skip the rewrites
# under test, so substitute FAST_RUN for it.
mode_opt = get_mode("FAST_RUN" if config.mode == "FAST_COMPILE" else config.mode)
class TestLocalSubtensorLift:
    """Tests for lifting `Subtensor` ops above `Elemwise` computations."""

    def test_basic(self):
        """The subtensor is applied before a unary elemwise."""
        x = matrix("x")
        f = function([x], exp(x)[0], mode=mode_opt)
        # Stack trace must survive the rewrite
        assert check_stack_trace(f, ops_to_check="all")
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 2
        assert isinstance(topo[0].op, Subtensor)
        assert topo[1].op == exp
        # Run once so DebugMode can exercise the compiled graph
        f([[0, 1], [2, 3]])

    def test_basic_1(self):
        """As test_basic, but the elemwise output is reused: no lifting."""
        x = matrix("x")
        f = function([x], [exp(x)[0], exp(x)], mode=mode_opt)
        # Stack trace must survive the rewrite
        assert check_stack_trace(f, ops_to_check=[Subtensor, Elemwise])
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 3
        assert topo[0].op == exp
        assert isinstance(topo[1].op, Subtensor)
        assert isinstance(topo[2].op, DeepCopyOp)
        # Run once so DebugMode can exercise the compiled graph
        f([[0, 1], [2, 3]])

    def test_basic_2(self):
        """The subtensor is lifted above an add with a broadcasted scalar."""
        x = matrix("x")
        y = scalar("y")
        z = matrix("z")
        out = exp(x + y + z)[0]
        f = function([x, y, z], out, mode=mode_opt)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 4
        assert isinstance(topo[0].op, Subtensor)
        assert isinstance(topo[1].op, DimShuffle)
        assert isinstance(topo[2].op, Subtensor)
        # The two adds are fused into a single Composite{add,add}
        assert isinstance(topo[3].op.scalar_op, ps.Composite)
        # Stack trace must survive the rewrite
        assert check_stack_trace(f, ops_to_check=[Subtensor])
        # Run once so DebugMode can exercise the compiled graph
        f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])

    def test_basic_3(self):
        """Same as test_basic_2, but indexing with a slice instead of an int."""
        x = matrix("x")
        y = scalar("y")
        z = matrix("z")
        sliced = exp(x + y + z)[0:2]
        f = function([x, y, z], sliced, mode=mode_opt)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 4
        assert isinstance(topo[0].op, Subtensor)
        assert isinstance(topo[1].op, DimShuffle)
        assert isinstance(topo[2].op, Subtensor)
        # The two adds are fused into a single Composite{add,add}
        assert isinstance(topo[3].op.scalar_op, ps.Composite)
        # Stack trace must survive the rewrite
        assert check_stack_trace(f, ops_to_check=[Subtensor])
        # Run once so DebugMode can exercise the compiled graph
        f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])

    def test_basic_4(self):
        """Lifting works through broadcasting introduced by a unary elemwise."""
        v = vector("y")
        f = function([v], exp(v.dimshuffle(0, "x"))[0], mode=mode_opt)
        # Stack trace must survive the rewrite
        assert check_stack_trace(f, ops_to_check="all")
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 3
        assert isinstance(topo[0].op, DimShuffle)
        assert isinstance(topo[1].op, Subtensor)
        assert topo[2].op == exp
        # Run once so DebugMode can exercise the compiled graph
        f([4, 5])

    @utt.assertFailure_fast
    def test_basic_5(self):
        """The rewrite must NOT fire when broadcasting mixes in another input.

        It *could* be extended to handle this case, but right now it doesn't,
        so it shouldn't try.
        """
        x = matrix("x")
        y = vector("y")
        f = function([x, y], exp(x + y)[0], mode=mode_opt)
        # The rewrite does not apply here, so there is no stack trace to check.
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 4
        assert isinstance(topo[0].op, DimShuffle)
        assert topo[1].op == add
        assert isinstance(topo[2].op, Subtensor)
        assert topo[3].op == inplace.exp_inplace
        # Run once so DebugMode can exercise the compiled graph
        f([[0, 1], [2, 3]], [4, 5])

    def test_basic_6(self):
        """No lifting when the elemwise output is reused for other computation."""
        x = matrix("x")
        y = vector("y")
        f = function([x, y], [exp(x + y)[0], exp(x + y) + x], mode=mode_opt)
        # The rewrite does not apply here, so there is no stack trace to check.
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 3
        assert isinstance(topo[0].op, DimShuffle)
        # The shared add/exp chain is fused into a single Composite{add,exp}
        assert isinstance(topo[1].op.scalar_op, ps.Composite)
        # The subtensor stays after the elemwise instead of being lifted
        assert isinstance(topo[2].op, Subtensor)
        # Run once so DebugMode can exercise the compiled graph
        f([[0, 1], [2, 3]], [4, 5])

    def test_basic_7(self):
        """Scalar output with a scalar input (no broadcasting needed).

        The rewrite used to fail here and display an ERROR message.
        """
        x = vector("x")
        y = scalar("y")
        f = function([x, y], exp(x + y)[0], mode=mode_opt)
        # Stack trace must survive the rewrite
        assert check_stack_trace(f, ops_to_check=Subtensor)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 2
        assert isinstance(topo[0].op, Subtensor)
        # Composite{add,exp}
        assert isinstance(topo[1].op.scalar_op, ps.Composite)
        # Run once so DebugMode can exercise the compiled graph
        f([1, 2, 3], 4)
def test_local_subtensor_of_alloc():
    """A `Subtensor` of an `Alloc` should not remain the final graph node.

    DebugMode should detect if something goes wrong.
    """
    # Test shape combinations of odd and even sizes, including
    # broadcastable (size-1) dimensions.
    for s in [(3, 5), (4, 6), (3, 8), (4, 7), (1, 5), (5, 1)]:
        x = tensor(
            dtype=config.floatX,
            shape=(1 if s[0] == 1 else None, 1 if s[1] == 1 else None),
        )
        xval = np.zeros(s, dtype=config.floatX)
        yval = np.arange(s[1], dtype=config.floatX)
        for y in [shared(yval), pt.constant([1.0])]:
            # The rows of yx are copies of y
            yx = pt.alloc(y, x.shape[0], x.shape[1])
            # Slice of each row
            z_mat = yx[:, 3:]
            assert z_mat.ndim == 2
            # Only one column
            z_vec = yx[:, 3]
            assert z_vec.ndim == 1
            # Indices whose results are vectors
            slicess = []
            if s[0] != 1:
                slicess.append((2, slice(None)))
            if s[1] != 1:
                slicess.append((slice(None), 3))
            # Indices whose results are matrices.  The trailing commas make
            # every entry a 1-tuple, consistent with the entries above; the
            # original had two bare slices here, which index identically but
            # were inconsistent with the rest of the list.
            slicess += [
                (slice(None), slice(3, None)),
                (slice(3, None),),
                (slice(3, None), slice(3, None)),
                (slice(1, 3), slice(None, -1)),
                (slice(None, None, 2),),
                (slice(1, None, 2),),
            ]
            for slices in slicess:
                # Idiomatic subscription instead of explicit __getitem__
                z = yx[slices]
                f = function([x], z)
                if config.mode != "FAST_COMPILE":
                    # A Subtensor may feed the Alloc's inputs, but must not
                    # be the last node of the rewritten graph
                    assert not isinstance(f.maker.fgraph.toposort()[-1].op, Subtensor)
                val = f(xval)
                assert xval[slices].shape == val.shape
@pytest.mark.parametrize(
    "x, s, idx, x_val, s_val",
    [
        (
            vector(),
            (iscalar(),),
            (1,),
            np.array([1, 2], dtype=config.floatX),
            np.array([2], dtype=np.int64),
        ),
        (
            matrix(),
            (iscalar(), iscalar()),
            (1,),
            np.array([[1, 2], [3, 4]], dtype=config.floatX),
            np.array([2, 2], dtype=np.int64),
        ),
        (
            matrix(),
            (iscalar(), iscalar()),
            (0,),
            np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
            np.array([2, 3], dtype=np.int64),
        ),
        (
            matrix(),
            (iscalar(), iscalar()),
            (1, 1),
            np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
            np.array([2, 3], dtype=np.int64),
        ),
        (
            tensor3(),
            (iscalar(), iscalar(), iscalar()),
            (-1,),
            np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
            np.array([2, 3, 5], dtype=np.int64),
        ),
        (
            tensor3(),
            (iscalar(), iscalar(), iscalar()),
            (-1, 0),
            np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
            np.array([2, 3, 5], dtype=np.int64),
        ),
    ],
)
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
    """`SpecifyShape` is lifted above constant-integer subtensors.

    For 0d results the `SpecifyShape` is dropped entirely; otherwise it is
    moved above the `Subtensor`.  Either way, the rewritten graph must
    compute the same values as the unrewritten one.
    """
    y = specify_shape(x, s)[idx]
    assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)

    rewrites = RewriteDatabaseQuery(include=[None])
    no_rewrites_mode = Mode(optimizer=rewrites)

    y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
    # Splat the arguments directly instead of building a throwaway list first
    y_val = y_val_fn(x_val, *s_val)

    # This rewrite should appear in the canonicalizations
    y_opt = rewrite_graph(y, clone=False)

    if y.ndim == 0:
        # SpecifyShape should be removed altogether
        assert isinstance(y_opt.owner.op, Subtensor)
        assert y_opt.owner.inputs[0] is x
    else:
        assert isinstance(y_opt.owner.op, SpecifyShape)

    y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
    y_opt_val = y_opt_fn(x_val, *s_val)

    assert np.allclose(y_val, y_opt_val)
@pytest.mark.parametrize(
    "x, s, idx",
    [
        (matrix(), (iscalar(), iscalar()), (slice(1, None),)),
        (matrix(), (iscalar(), iscalar()), (slicetype(),)),
        (matrix(), (iscalar(), iscalar()), (1, 0)),
    ],
)
def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
    """Indices the SpecifyShape lift does not handle leave no SpecifyShape on top."""
    indexed = specify_shape(x, s)[idx]
    # This rewrite should appear in the canonicalizations
    rewritten = rewrite_graph(indexed, clone=False)
    assert not isinstance(rewritten.owner.op, SpecifyShape)
class TestLocalSubtensorMakeVector:
    """Tests for `local_subtensor_make_vector`.

    The rewrite turns constant (advanced) subtensors of a `MakeVector` into
    either the scalar inputs themselves or a smaller `MakeVector`.
    """

    mode = get_mode("FAST_RUN").including("local_subtensor_make_vector")

    def test_scalar_idx(self):
        # A constant integer index collapses to the matching scalar input
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        f = function([x, y, z], v[0], mode=self.mode)
        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, DeepCopyOp)
        assert f(0, 1, 2) == 0

    def test_idx_symbolic(self):
        # A symbolic (tensor) index keeps a MakeVector and preserves the
        # original dtype
        x, y, z = iscalars("xyz")
        v = MakeVector("int32")(x, y, z)
        idx = pt.as_tensor([0], dtype=np.int64)
        f = function([x, y, z], v[idx], mode=self.mode)
        opt_fgraph = f.maker.fgraph
        assert opt_fgraph.outputs[0].dtype == "int32"
        assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector)
        assert f(0, 1, 2) == np.array([0], dtype=np.int32)

    def test_slice_idx_start(self):
        # v[1:] becomes a MakeVector of the last two inputs
        x, y, z = iscalars("xyz")
        v = MakeVector("int32")(x, y, z)
        f = function([x, y, z], v[1:], mode=self.mode, on_unused_input="ignore")
        opt_fgraph = f.maker.fgraph
        assert opt_fgraph.outputs[0].dtype == "int32"
        assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector)
        assert len(opt_fgraph.outputs[0].owner.inputs) == 2
        r = f(0, 1, 2)
        assert r[0] == 1 and r[1] == 2

    def test_slice_idx_stop(self):
        # v[:2] becomes a MakeVector of the first two inputs
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        f = function([x, y, z], v[:2], mode=self.mode)
        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, MakeVector)
        assert len(prog[0].inputs) == 2
        r = f(0, 1, 2)
        assert r[0] == 0 and r[1] == 1

    def test_slice_idx_step(self):
        # v[::2] becomes a MakeVector of the first and last inputs
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        f = function([x, y, z], v[::2], mode=self.mode)
        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, MakeVector)
        assert len(prog[0].inputs) == 2
        r = f(0, 1, 2)
        assert r[0] == 0 and r[1] == 2

    def test_AdvancedSubtensor1_idx(self):
        # A constant integer-list index becomes a MakeVector of those inputs
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        f = function([x, y, z], v[[0, 2]], mode=self.mode)
        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, MakeVector)
        assert len(prog[0].inputs) == 2
        r = f(0, 1, 2)
        assert r[0] == 0 and r[1] == 2

    def test_MakeVector_idx(self):
        # Indexing with a constant MakeVector of indices also collapses to a
        # single MakeVector.  Fixed: the original unpacked a fourth scalar
        # `q` from lscalars("xyzq") that was immediately rebound below and
        # never used.
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)
        q = make_vector(0, 2)
        f = function([x, y, z], v[q], mode=self.mode)
        prog = f.maker.fgraph.toposort()
        assert len(prog) == 1
        assert isinstance(prog[0].op, MakeVector)
        assert len(prog[0].inputs) == 2
        r = f(0, 1, 2)
        assert r[0] == 0 and r[1] == 2

    def test_stack_trace(self):
        # The rewrite must copy the stack trace onto every new MakeVector
        # node it inserts
        x, y, z = lscalars("xyz")
        v = make_vector(x, y, z)

        mode = get_default_mode().including("local_subtensor_make_vector")

        # Subtensor cases where local_subtensor_make_vector inserts a new
        # MakeVector node
        v_subtensors = [v[:2], v[::2], v[[0, 2]]]

        for v_subtensor in v_subtensors:
            f = function([x, y, z], v_subtensor, mode=mode)
            assert check_stack_trace(f, ops_to_check="all")

    def test_empty_subtensor(self):
        # v[()] is a no-op index: the rewrite returns the MakeVector itself
        x, y = lscalars("xy")
        v = make_vector(x, y)
        out = v[()]
        fgraph = FunctionGraph(outputs=[out], clone=False)
        node = fgraph.outputs[0].owner
        assert isinstance(node.op, Subtensor)
        assert local_subtensor_make_vector.transform(fgraph, node) == [v]
def test_local_subtensor_shape_constant():
    """Subtensors of a `Shape` over static dims constant-fold."""
    dim0 = tensor(dtype=np.float64, shape=(1, None)).shape[0]
    (res,) = local_subtensor_shape_constant.transform(None, dim0.owner)
    assert isinstance(res, Constant)
    assert res.data == 1

    # Make sure it's part of the canonicalizations
    res = rewrite_graph(dim0)
    assert isinstance(res, Constant)
    assert res.data == 1

    # A symbolic index, an open slice from 0, and a symbolic-start slice over
    # a partially-static shape must all be left untouched
    for idx in (lscalar(), slice(0, None), slice(lscalar(), None)):
        var = _shape(tensor(dtype=np.float64, shape=(1, None)))[idx]
        assert not local_subtensor_shape_constant.transform(None, var.owner)

    tail = _shape(tensor(dtype=np.float64, shape=(1, 1)))[1:]
    (res,) = local_subtensor_shape_constant.transform(None, tail.owner)
    assert isinstance(res, Constant)
    assert np.array_equal(res.data, [1])

    tail = _shape(tensor(dtype=np.float64, shape=(None, 1, 1)))[1:]
    (res,) = local_subtensor_shape_constant.transform(None, tail.owner)
    assert isinstance(res, Constant)
    assert np.array_equal(res.data, [1, 1])

    # A test for a non-`TensorType`
    class MyType(Type):
        def filter(self, *args, **kwargs):
            raise NotImplementedError()

        def __eq__(self, other):
            return isinstance(other, MyType) and other.thingy == self.thingy

    non_tensor_dim = shape(Variable(MyType(), None, None))[0]
    assert not local_subtensor_shape_constant.transform(None, non_tensor_dim.owner)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论