Commit f25a624a authored by Jesse Grabowski, committed by Ricardo Vieira
Parent 23427a0a
......@@ -4,6 +4,7 @@ from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
# Load dispatch specializations
import pytensor.link.jax.dispatch.blas
import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.einsum
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.pad
......
import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.einsum import Einsum
@jax_funcify.register(Einsum)
def jax_funcify_Einsum(op, **kwargs):
    """Dispatch einsum to JAX.

    This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level.
    This happens when some of the dimension lengths are unknown. This is never a problem in JAX,
    as it always compiles a function per runtime input shape.
    """
    spec = op.subscripts

    def einsum(*operands):
        # Let JAX choose the optimal contraction order at trace time.
        return jnp.einsum(spec, *operands, optimize="optimal")

    return einsum
......@@ -151,6 +151,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable
# isort: off
from pytensor.tensor.einsum import einsum
from pytensor.tensor.functional import vectorize
# isort: on
......
......@@ -1700,21 +1700,22 @@ class Alloc(COp):
return False
for client, idx in clients:
if isinstance(client.op, Output):
client_op = client.op
if isinstance(client_op, Output):
# If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold.
return False
# Allow alloc to be lifted out of Elemwise before constant folding it
elif isinstance(client.op, Elemwise):
return None
# Op's through which Alloc can be lifted
elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join):
return False
# Same for Blockwise, unless it has no batch_dims
elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
return None
elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client):
return False
elif (
# The following ops work inplace of their input id 0.
idx == 0
and isinstance(
client.op,
client_op,
pytensor.tensor.subtensor.IncSubtensor
| pytensor.tensor.subtensor.AdvancedIncSubtensor1
| pytensor.tensor.subtensor.AdvancedIncSubtensor
......@@ -2035,10 +2036,15 @@ def transpose(x, axes=None):
_x = as_tensor_variable(x)
if axes is None:
axes = list(range((_x.type.ndim - 1), -1, -1))
axes = tuple(range((_x.type.ndim - 1), -1, -1))
if tuple(axes) == tuple(range(len(axes))):
# No-op
return _x
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T"
return ret
......@@ -3950,6 +3956,10 @@ def moveaxis(
source = normalize_axis_tuple(source, a.ndim, "source")
destination = normalize_axis_tuple(destination, a.ndim, "destination")
if source == destination:
# It's a no-op
return a
if len(source) != len(destination):
raise ValueError(
"`source` and `destination` arguments must have the same number of elements"
......@@ -4260,9 +4270,7 @@ atleast_2d = partial(atleast_Nd, n=2)
atleast_3d = partial(atleast_Nd, n=3)
def expand_dims(
a: np.ndarray | TensorVariable, axis: tuple[int, ...]
) -> TensorVariable:
def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
"""Expand the shape of an array.
Insert a new axis that will appear at the `axis` position in the expanded
......@@ -4281,7 +4289,7 @@ def expand_dims(
"""
a = as_tensor(a)
if not isinstance(axis, tuple | list):
if not isinstance(axis, Sequence):
axis = (axis,)
out_ndim = len(axis) + a.ndim
......
Diff collapsed.
from collections.abc import Callable
from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorVariable
from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.tensor.variable import TensorVariable
def vectorize(func: Callable, signature: str | None = None) -> Callable:
......
......@@ -3,10 +3,9 @@ import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.einsum
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops
# Register JAX specializations
import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
......
......@@ -52,6 +52,7 @@ from pytensor.tensor.basic import (
TensorFromScalar,
alloc,
as_tensor_variable,
atleast_Nd,
cast,
extract_constant,
fill,
......@@ -1219,3 +1220,123 @@ def local_merge_alloc(fgraph, node):
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_alloc(fgraph, node):
    """Lift a DimShuffle above an Alloc.

    dimshuffle{x, 0, 1}(alloc([3 4], 3, 2)) -> alloc([3 4], 1, 3, 2)
    """
    [shuffled_inp] = node.inputs
    inp_node = shuffled_inp.owner
    if inp_node is None or not isinstance(inp_node.op, Alloc):
        return None

    dimshuffle_op = node.op
    value, *shape = inp_node.inputs

    # Make implicit (left-padded) dims of the value explicit
    value = atleast_Nd(value, n=len(shape))

    # Apply the dimshuffle to the value and to the allocation shape
    shuffled_value = value.dimshuffle(dimshuffle_op.new_order)
    shuffled_shape = [shape[i] for i in dimshuffle_op.shuffle]
    for new_axis in dimshuffle_op.augment:
        shuffled_shape.insert(new_axis, 1)

    return [alloc(shuffled_value, *shuffled_shape)]
@register_specialize("shape_unsafe")
@node_rewriter([Join])
def local_join_of_alloc(fgraph, node):
    """Rewrite a Join of Alloc nodes to an Alloc of the Join nodes.

    Broadcasted dimensions shared by every joined Alloc (other than the join
    axis itself) are stripped before the Join and re-allocated once on the
    joined result.
    """
    axis, *tensors = node.inputs

    if len(tensors) < 2:
        # Let other rewrite handle the useless Join
        return

    if not isinstance(axis, Constant):
        return

    core_tensors = []
    alloc_shapes = []
    for tensor in tensors:
        if tensor.owner is None:
            return

        # tensor = expand_dims_to_alloc(tensor)
        if not isinstance(tensor.owner.op, Alloc):
            return

        value, *shape = tensor.owner.inputs
        # Introduce explicit batch dims
        value = atleast_Nd(value, n=len(shape))
        core_tensors.append(value)
        alloc_shapes.append(shape)

    # Find which allocated dimensions can be lifted
    # Axis can never be lifted
    # Non-axis allocated dimensions can be lifted if they are all broadcastable
    [out] = node.outputs
    axis = axis.data

    broadcasted_dims = list(
        zip(
            *(
                [
                    bef and not aft
                    for bef, aft in zip(
                        core_tensor.type.broadcastable,
                        tensor.type.broadcastable,
                        strict=True,
                    )
                ]
                for core_tensor, tensor in zip(core_tensors, tensors, strict=True)
            )
        )
    )

    lifteable_alloc_dims = {
        dim
        for dim in range(out.type.ndim)
        if dim != axis and all(broadcasted_dims[dim])
    }

    if not lifteable_alloc_dims:
        return

    # Lift the allocated dimensions
    new_tensors = []
    # BUGFIX: also zip in the original joined tensor, so the stack trace is
    # copied from the tensor each new Alloc replaces. Previously the stale
    # loop variable `tensor` (the last joined input) was used for every one.
    for core_tensor, alloc_shape, tensor in zip(
        core_tensors, alloc_shapes, tensors, strict=True
    ):
        pre_join_shape = [
            1 if i in lifteable_alloc_dims else alloc_dim
            for i, alloc_dim in enumerate(alloc_shape)
        ]
        new_tensor = alloc(core_tensor, *pre_join_shape)
        copy_stack_trace(tensor, new_tensor)
        new_tensors.append(new_tensor)

    new_join = node.op(axis, *new_tensors)
    copy_stack_trace(node.outputs[0], new_join)

    # Reintroduce the lifted dims
    post_join_shape = []
    for i, alloc_dims in enumerate(zip(*alloc_shapes)):
        if i == axis:
            # The alloc dim along the axis is the sum of all the pre-join alloc dims
            post_join_shape.append(add(*alloc_dims))
        else:
            # Otherwise the shapes should all match. We prioritize constants if any
            for best_alloc_dim in alloc_dims:
                if isinstance(best_alloc_dim, Constant):
                    break
            post_join_shape.append(best_alloc_dim)

    new_out = alloc(new_join, *post_join_shape)
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
......@@ -10,6 +10,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
......@@ -67,10 +68,16 @@ optdb.register(
def local_eager_useless_unbatched_blockwise(fgraph, node):
if isinstance(
node.op.core_op,
Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor,
Dot
| Alloc
| ARange
| Subtensor
| AdvancedSubtensor
| AdvancedIncSubtensor
| Reshape,
):
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectored at runtime,
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return local_useless_unbatched_blockwise.fn(fgraph, node)
......@@ -97,62 +104,67 @@ def local_blockwise_alloc(fgraph, node):
BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
"""
if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
return None
op: Blockwise = node.op # type: ignore
batch_ndim = op.batch_ndim(node)
if not batch_ndim:
return None
if not any(var.owner and isinstance(var.owner.op, Alloc) for var in node.inputs):
return None
new_inputs = []
batch_shapes = []
can_push_any_alloc = False
for inp, inp_sig in zip(node.inputs, op.inputs_sig):
if inp.owner and isinstance(inp.owner.op, Alloc):
# Push batch dims from Alloc
value, *shape = inp.owner.inputs
# Check what to do with the value of the Alloc
squeezed_value = _squeeze_left(value, batch_ndim)
missing_ndim = len(shape) - value.type.ndim
if (
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
!= inp.type.broadcastable[batch_ndim:]
):
# We still need an Alloc for the core dims
core_shape = shape[batch_ndim:]
# And the batch dims of the squeezed value
squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape)
batch_shape = [
1 if broadcastable else dim
for broadcastable, dim in zip(
squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
if not all(inp.type.broadcastable[:batch_ndim]):
if inp.owner and isinstance(inp.owner.op, Alloc):
# Push batch dims from Alloc
value, *shape = inp.owner.inputs
# Check what to do with the value of the Alloc
squeezed_value = _squeeze_left(value, batch_ndim)
missing_ndim = len(shape) - value.type.ndim
if (
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
!= inp.type.broadcastable[batch_ndim:]
):
# We still need an Alloc for the core dims
core_shape = shape[batch_ndim:]
# And the batch dims of the squeezed value
squeezed_value_batch_ndim = squeezed_value.type.ndim - len(
core_shape
)
]
squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
if squeezed_value.type.broadcastable == inp.type.broadcastable:
# We can't change anything about this Alloc input
new_inputs.append(inp)
continue
# We can push batch dims of this Alloc input
batch_shapes.append(
tuple(
1 if broadcastable else dim
for broadcastable, dim in zip(
inp.type.broadcastable, shape[:batch_ndim]
batch_shape = [
1 if broadcastable else dim
for broadcastable, dim in zip(
squeezed_value.type.broadcastable[
:squeezed_value_batch_ndim
],
tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
)
]
squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
if squeezed_value.type.broadcastable == inp.type.broadcastable:
# We can't change anything about this Alloc input
new_inputs.append(inp)
continue
# We can push batch dims of this Alloc input
batch_shapes.append(
tuple(
1 if broadcastable else dim
for broadcastable, dim in zip(
inp.type.broadcastable, shape[:batch_ndim]
)
)
)
)
new_inputs.append(squeezed_value)
can_push_any_alloc = True
new_inputs.append(squeezed_value)
can_push_any_alloc = True
continue
else:
# Nothing to do with this input other than removing dummy batch dims
new_inputs.append(_squeeze_left(inp, batch_ndim))
# Nothing to do with this input other than removing dummy batch dims
new_inputs.append(_squeeze_left(inp, batch_ndim))
if not can_push_any_alloc:
return None
......@@ -167,17 +179,15 @@ def local_blockwise_alloc(fgraph, node):
missing_ndim = old_out_type.ndim - new_out_type.ndim
batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples
if old_out_type.broadcastable[i]:
continue
for batch_dim in batch_dims:
if batch_dim == 1:
continue
batch_shape[i] = batch_dim
if isinstance(batch_dim, Constant):
# Give preference to Constants
batch_shape[i] = batch_dim
break
elif old_out_type.broadcastable[i]:
# Only use non Constant shapes if absolutely necessary
# Otherwise, we use the shape of the non-alloc output
batch_shape[i] = batch_dim
copy_stack_trace(node.outputs, new_outs)
new_outs = [
......@@ -190,3 +200,28 @@ def local_blockwise_alloc(fgraph, node):
]
copy_stack_trace(node.outputs, new_outs)
return new_outs
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_reshape(fgraph, node):
    """Rewrite away square Blockwise reshapes.

    Reshape is tricky to vectorize eagerly, because a graph like
    `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
    that must be vectorized before we arrive at the reshape operation.

    For the square Reshape case, we must wait for all the intermediate
    operations to be lifted as Allocs.
    """
    if not isinstance(node.op.core_op, Reshape):
        return None

    x, output_shape = node.inputs
    batch_ndim = node.op.batch_ndim(node)
    # Only applies when the target shape is not itself batched (every batch
    # entry is reshaped to the same core shape)
    if all(output_shape.type.broadcastable[:batch_ndim]):
        # Keep x's batch dims and apply the core reshape after them
        batched_shape = x.shape[:batch_ndim]
        core_reshape = _squeeze_left(output_shape, batch_ndim)
        new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
        copy_stack_trace(node.outputs[0], new_out)
        return [new_out]
from typing import cast
from pytensor.graph import Apply, FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.tensor.einsum import Einsum, einsum
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.ofg import inline_ofg_node
from pytensor.tensor.variable import TensorVariable
@register_specialize
@node_rewriter([Einsum])
def optimize_einsum_inner_graph(
    fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
    """Try to optimize an einsum that was not optimizable at definition time.

    This can happen when users replace a graph without rebuilding,
    or when more specialized static shapes are found during rewrites.
    """
    op: Einsum = node.op
    if op.optimized:
        # Nothing to do: a contraction order was already chosen
        return None

    operands = node.inputs
    shapes_are_static = all(None not in operand.type.shape for operand in operands)
    if not shapes_are_static:
        return None

    optimized_out = einsum(op.subscripts, *operands)
    # Rebuilding with fully static shapes must yield an optimized Einsum
    assert optimized_out.owner.op.optimized

    copy_stack_trace(node.outputs[0], optimized_out)
    return [optimized_out]
@register_specialize
@node_rewriter([Einsum])
def inline_optimized_einsum(
    fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
    """Inline einsums that are already optimized.

    This allows the inner graph to be optimized with the rest of the graph,
    now that we got the ordering right.
    """
    if not node.op.optimized:
        return None
    return cast(list[TensorVariable], inline_ofg_node(node))
from pytensor import clone_replace
from typing import cast
from pytensor import Variable, clone_replace
from pytensor.compile import optdb
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import node_rewriter
from pytensor.graph import Apply, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out
from pytensor.tensor.basic import AllocDiag
from pytensor.tensor.rewriting.basic import register_specialize
def inline_ofg_node(node: Apply) -> list[Variable]:
    """Return the inner outputs of an OpFromGraph node with its inputs substituted in."""
    ofg = node.op
    assert isinstance(ofg, OpFromGraph)
    replacements = dict(zip(ofg.inner_inputs, node.inputs))
    inlined = clone_replace(ofg.inner_outputs, replacements)
    copy_stack_trace(ofg.inner_outputs, inlined)
    return cast(list[Variable], inlined)
@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
"""
......@@ -18,10 +30,7 @@ def inline_ofg_expansion(fgraph, node):
if not op.is_inline:
return False
new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
copy_stack_trace(op.inner_outputs, new_out)
return new_out
return inline_ofg_node(node)
# We want to run this before the first merge optimizer
......@@ -61,8 +70,4 @@ def late_inline_OpFromGraph(fgraph, node):
-------
"""
op = node.op
new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
copy_stack_trace(op.inner_outputs, new_out)
return new_out
return inline_ofg_node(node)
......@@ -749,51 +749,43 @@ pytensor.compile.mode.optdb.register(
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
def local_reshape_chain(op):
@node_rewriter([op])
def f(fgraph, node):
"""
Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
"""
if not check_chain(node, op, op):
return False
# TODO: this can permit a failing program to run by eliminating
# the lower reshape
rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
# Copy over stacktrace from previous output node, as any error
# in new computational graph would have been caused by last op
# in the old computational graph.
copy_stack_trace(node.outputs, rval)
# It might happen that the desired output of this node has a
# broadcastable pattern that does not match that of 'rval'. This is
# when originally, we were able to figure out that one of the
# dimensions of the reshape is one, but some other transformation
# replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this
# rewrite.
if rval.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
if s1 == 1 or s2 == 1
):
return [rval]
else:
return False
return f
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Reshape])
def local_reshape_chain(fgraph, node):
"""
Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)
"""
if not check_chain(node, Reshape, Reshape):
return False
register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
# Copy over stacktrace from previous output node, as any error
# in new computational graph would have been caused by last op
# in the old computational graph.
copy_stack_trace(node.outputs, rval)
# It might happen that the desired output of this node has a
# broadcastable pattern that does not match that of 'rval'. This is
# when originally, we were able to figure out that one of the
# dimensions of the reshape is one, but some other transformation
# replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this
# rewrite.
if rval.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
if s1 == 1 or s2 == 1
):
return [rval]
@register_useless
@register_canonicalize
@register_stabilize
@register_useless("shape_unsafe")
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Reshape])
def local_useless_reshape(fgraph, node):
"""Remove two kinds of useless `Reshape`.
......@@ -802,24 +794,17 @@ def local_useless_reshape(fgraph, node):
- Remove `Reshape` when reshaping to the shape of the input.
"""
inp = node.inputs[0]
output = node.outputs[0]
output_shape = node.inputs[1]
inp, output_shape = node.inputs
[output] = node.outputs
if inp.type.ndim != output.type.ndim:
return False
# Simple case: both input and output have a single dimension.
# TODO FIXME XXX: This could hide errors if the user provides inconsistent
# shapes.
if (
inp.type.ndim == 1
and output.type.ndim == 1
and all(
s1 == s2
for s1, s2 in zip(inp.type.shape, output.type.shape)
if s1 == 1 or s2 == 1
)
and inp.type.broadcastable == output.type.broadcastable
):
return [inp]
......@@ -832,8 +817,15 @@ def local_useless_reshape(fgraph, node):
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
# broadcastable and constant dimensions
if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
output_shape_is = output_shape.owner.inputs
if isinstance(output_shape, Constant) or (
output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
):
if isinstance(output_shape, Constant):
output_shape_is = [
as_tensor_variable(dim, ndim=0) for dim in output_shape.data
]
else:
output_shape_is = output_shape.owner.inputs
shape_feature = getattr(fgraph, "shape_feature", None)
......@@ -865,9 +857,9 @@ def local_useless_reshape(fgraph, node):
shape_match[dim] = True
continue
# Match 1 if input.type.shape[dim] == 1
# Match constant if input.type.shape[dim] == constant
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
if inp.type.shape[dim] == cst_outshp_i:
shape_match[dim] = True
continue
......@@ -881,17 +873,18 @@ def local_useless_reshape(fgraph, node):
if shape_feature:
inpshp_i = shape_feature.get_shape(inp, dim)
if inpshp_i == outshp_i or (
extract_constant(inpshp_i, only_process_constants=1)
== extract_constant(outshp_i, only_process_constants=1)
extract_constant(inpshp_i, only_process_constants=True)
== extract_constant(outshp_i, only_process_constants=True)
):
shape_match[dim] = True
continue
if all(shape_match) and nb_m1 <= 1:
if nb_m1 <= 1 and all(shape_match):
return [inp]
if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
return [inp]
# TODO later: if all the shapes except one match, we may want to
# consider it useless as well, like we do in the 1-dim case.
return False
......@@ -910,9 +903,8 @@ def local_reshape_to_dimshuffle(fgraph, node):
-> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
"""
op = node.op
inp = node.inputs[0]
output = node.outputs[0]
output_shape = node.inputs[1]
inp, output_shape = node.inputs
[output] = node.outputs
dimshuffle_new_order = []
new_output_shape = []
......@@ -944,7 +936,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
"""
......
......@@ -842,13 +842,13 @@ class Reshape(COp):
@_vectorize_node.register(Reshape)
def _vectorize_reshape(op, node, x, shape):
from pytensor.tensor.blockwise import vectorize_node_fallback
old_x, old_shape = node.inputs
batched_ndims = x.type.ndim - old_x.type.ndim
if as_tensor_variable(shape).type.ndim != 1:
raise NotImplementedError(
"It is not possible to vectorize the shape argument of Reshape"
)
return vectorize_node_fallback(op, node, x, shape)
if len(tuple(old_shape)) == len(tuple(shape)):
new_shape = [*x.shape[:batched_ndims], *shape]
......
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
jax = pytest.importorskip("jax")
def test_jax_einsum():
    """Einsum compiled in JAX mode matches np.einsum."""
    subscripts = "ij, jk, kl -> il"
    shapes = ((3, 5), (5, 2), (2, 4))

    x, y, z = (np.random.rand(*shape) for shape in shapes)
    x_pt, y_pt, z_pt = (
        pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
    )

    out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
    fn = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX")
    np.testing.assert_allclose(fn(x, y, z), np.einsum(subscripts, x, y, z))
@pytest.mark.xfail(raises=NotImplementedError)
def test_ellipsis_einsum():
    """Ellipsis subscripts are not implemented yet."""
    subscripts = "...i,...i->..."

    x_val = np.random.rand(2, 5)
    y_val = np.random.rand(2, 5)
    x_pt = pt.tensor("x", shape=x_val.shape)
    y_pt = pt.tensor("y", shape=y_val.shape)

    out = pt.einsum(subscripts, x_pt, y_pt)
    fn = pytensor.function([x_pt, y_pt], out, mode="JAX")
    np.testing.assert_allclose(fn(x_val, y_val), np.einsum(subscripts, x_val, y_val))
from functools import partial
from pytensor import function
from pytensor.graph import FunctionGraph, rewrite_graph
import numpy as np
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
from pytensor.graph.basic import equal_computations
from pytensor.scalar import log as scalar_log
from pytensor.tensor import add, alloc, matrix, tensor, tensor3
......@@ -9,6 +11,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv
from pytensor.tensor.rewriting.blockwise import local_useless_blockwise
from pytensor.tensor.shape import Reshape
def test_useless_blockwise_of_elemwise():
......@@ -45,7 +48,7 @@ def test_blockwise_alloc():
rewrite = partial(
rewrite_graph,
include=("ShapeOpt", "specialize"),
exclude=("local_useless_unbatched_blockwise",),
exclude=("local_useless_unbatched_blockwise", "local_dimshuffle_alloc"),
)
vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)")
......@@ -104,7 +107,9 @@ def test_blockwise_alloc():
y = tensor("y", shape=())
out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5)
assert equal([rewrite(out)], [expected_out])
assert equal(
[rewrite(out)], [expected_out]
), None # pytensor.dprint([expected_out, rewrite(out)], print_type=True)
x = tensor("x", shape=(5,))
y = tensor("y", shape=())
......@@ -118,3 +123,27 @@ def test_blockwise_alloc():
out = vector_add(x, alloc(y, 5))
expected_out = out
assert equal([rewrite(out)], [expected_out])
def test_blockwise_reshape():
    # Vectorizing a reshape graph should yield a Blockwise(Reshape), which the
    # rewrites then collapse back to a plain Reshape with the batch dim kept.
    x = tensor("x", shape=(None, None, None))
    y = x.reshape([x.shape[0] * x.shape[1], -1])

    new_x = tensor("x", shape=(None, None, None, None))
    new_y = vectorize_graph(y, {x: new_x})
    # Not a plain (unbatched) Reshape ...
    assert not isinstance(new_y.owner.op, Reshape)
    # ... but a Blockwise whose core op is Reshape
    assert isinstance(new_y.owner.op, Blockwise) and isinstance(
        new_y.owner.op.core_op, Reshape
    )

    # After canonicalize+specialize the Blockwise reshape becomes a plain Reshape
    rewritten_y = rewrite_graph(
        new_y, include=("canonicalize", "specialize"), clone=True
    )
    assert isinstance(rewritten_y.owner.op, Reshape)

    # Both graphs must compute the same values (evaluated without rewrites)
    no_rewrites = Mode(linker="py", optimizer=None)
    test_x = np.arange(5 * 4 * 3 * 2).reshape(5, 4, 3, 2).astype(config.floatX)
    np.testing.assert_allclose(
        new_y.eval({"x": test_x}, mode=no_rewrites),
        rewritten_y.eval({"x": test_x}, mode=no_rewrites),
    )
from functools import partial
from pytensor.graph import ancestors, rewrite_graph
from pytensor.tensor import einsum, specify_shape, tensor
from pytensor.tensor.einsum import Einsum
specialize_rewrite = partial(rewrite_graph, include=("specialize",), clone=True)
def test_einsum_optimization():
    """Einsum with unknown shapes survives specialization un-optimized;
    with static shapes it is optimized and inlined away."""
    a, b, c = (tensor(name, shape=(None, None)) for name in "abc")

    dynamic_out = einsum("ij,ij,jk->ik", a, b, c)
    assert not dynamic_out.owner.op.optimized

    rewritten = specialize_rewrite(dynamic_out)
    # Shapes unknown: the Einsum op is still present after specialization
    assert isinstance(rewritten.owner.op, Einsum)

    static_inputs = [
        specify_shape(inp, shape)
        for inp, shape in zip((a, b, c), [(2, 3), (2, 3), (3, 5)])
    ]
    static_out = dynamic_out.owner.clone_with_new_inputs(
        static_inputs
    ).default_output()
    assert not static_out.owner.op.optimized

    rewritten = specialize_rewrite(static_out)
    # Einsum was inlined because it was optimized
    assert not isinstance(rewritten.owner.op, Einsum)
    # Sanity check that it's not buried in the graph
    assert not any(
        isinstance(var.owner.op, Einsum)
        for var in ancestors([rewritten])
        if var.owner
    )
......@@ -337,6 +337,52 @@ class TestLocalUselessReshape:
topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, Reshape) for n in topo)
def test_constant_shape(self):
# Where reshape is a constant that matches the shape
x = matrix(shape=(2, 3))
shape = pt.as_tensor(np.array([2, 3]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is x
x = matrix(shape=(2, 3))
shape = pt.as_tensor(np.array([-1, 3]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is x
x = matrix(shape=(None, 3))
shape = pt.as_tensor(np.array([-1, 3]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is x
x = matrix(shape=(None, 3))
shape = pt.as_tensor(np.array([2, 3]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
# This could be rewritten as a specify_shape(x, (2, 3))
assert new_out is not x
x = matrix(shape=(2, 3))
shape = pt.as_tensor(np.array([3, 2]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is not x
def test_all_but_one_match(self):
x = matrix(shape=(None, None))
shape = [x.shape[0], 3]
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert equal_computations([new_out], [specify_shape(x, (None, 3))])
# Rewrite does not apply if there's also a -1
shape = [-1, 3]
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is out
class TestLocalReshapeToDimshuffle:
def setup_method(self):
......
......@@ -3847,8 +3847,10 @@ def test_transpose():
assert np.all(t2d == np.transpose(x2v, [0, 1]))
assert np.all(t3d == np.transpose(x3v, [0, 2, 1]))
# Check we don't introduce useless transpose
assert ptb.transpose(x1) is x1
# Check that we create a name.
assert ptb.transpose(x1).name == "x1.T"
assert ptb.transpose(x2).name == "x2.T"
assert ptb.transpose(x3).name == "x3.T"
assert ptb.transpose(dmatrix()).name is None
......
from functools import partial
from string import ascii_lowercase
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
from pytensor.tensor.shape import Reshape
# Fail for unexpected warnings in this file
pytestmark = pytest.mark.filterwarnings("error")
floatX = pytensor.config.floatX
ATOL = RTOL = 1e-8 if floatX == "float64" else 1e-4
def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None:
    """Assert `fgraph` (including inner graphs) contains no offending Blockwise.

    With `core_op=None` any Blockwise node fails; otherwise only Blockwise
    nodes wrapping that core op do.
    """
    for node in fgraph.apply_nodes:
        op = node.op
        if isinstance(op, Blockwise):
            if core_op is None:
                raise AssertionError
            assert not isinstance(op.core_op, core_op)

        if isinstance(op, HasInnerGraph):
            # InnerGraph Ops can be rewritten without modifying the original fgraph
            inner_fgraph = op._fn.maker.fgraph if hasattr(op, "_fn") else op.fgraph
            assert_no_blockwise_in_graph(inner_fgraph, core_op=core_op)
def test_iota():
    """_iota fills each position with its index along the requested axis."""
    mode = Mode(linker="py", optimizer=None)
    # Axis 0: every row is constant at its row index
    np.testing.assert_allclose(
        _iota((4, 8), 0).eval(mode=mode),
        [[row] * 8 for row in range(4)],
    )
    # Axis 1: every row counts 0..7
    np.testing.assert_allclose(
        _iota((4, 8), 1).eval(mode=mode),
        [list(range(8)) for _ in range(4)],
    )
def test_delta():
    """_delta is 1 where the indices along the given axes agree, 0 elsewhere."""
    mode = Mode(linker="py", optimizer=None)
    # Plain 2x2 identity over axes (0, 1)
    np.testing.assert_allclose(
        _delta((2, 2), (0, 1)).eval(mode=mode),
        np.eye(2),
    )
    # Extra trailing axis is broadcast: delta depends only on axes (0, 1)
    np.testing.assert_allclose(
        _delta((2, 2, 2), (0, 1)).eval(mode=mode),
        np.broadcast_to(np.eye(2)[:, :, None], (2, 2, 2)),
    )
def test_general_dot():
    # Batched tensordot: _general_dot must match np.tensordot vectorized over
    # the operands' batch dimensions, and (outside FAST_COMPILE) compile
    # without any Blockwise Reshape left in the graph.
    rng = np.random.default_rng(45)
    signature = "(l0,a0,a1,l1),(a1,r0,r1,a0)->(l0,l1,r0,r1)"
    tensordot_axes = [(-3, -2), (-1, -4)]

    # X has two batch dims
    # Y has one batch dim
    x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
    y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
    out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])
    fn = pytensor.function([x, y], out)
    # fn.dprint(print_type=True)
    if config.mode != "FAST_COMPILE":
        assert_no_blockwise_in_graph(fn.maker.fgraph, Reshape)

    # Reference: tensordot over the core dims, vectorized over batch dims
    np_batched_tensordot = np.vectorize(
        partial(np.tensordot, axes=tensordot_axes), signature=signature
    )
    x_test = rng.normal(size=x.type.shape).astype(floatX)
    y_test = rng.normal(size=y.type.shape).astype(floatX)
    np.testing.assert_allclose(
        fn(x_test, y_test), np_batched_tensordot(x_test, y_test), atol=ATOL, rtol=RTOL
    )
@pytest.mark.parametrize("static_shape_known", [True, False])
@pytest.mark.parametrize(
    "signature",
    [
        "ij",
        "ji",
        "ii->i",
        "ii",
        "ij->",
        "ij->j",
        "ij->i",
        "ij,ij->ij",
        "ij,ji->ij",
        "ij,ji->ji",
        "ij,jk",
        "kj,ji",
        "ij,kj->ik",
        "ik,kj->ikj",
        "ij,kl->ijkl",
        "ij,jk,kl->il",
        "kl,ij,jk->il",
        "oij,imj,mjkn,lnk,plk->op",
    ],
)
def test_einsum_signatures(static_shape_known, signature):
    # Give each index letter a distinct size so axis mix-ups change the result
    letters_to_dims = dict(zip("ijklmnop", [2, 3, 5, 7, 11, 13, 17, 19], strict=True))

    inputs = signature.split("->")[0].split(",")

    shapes = [tuple(letters_to_dims[letter] for letter in inp) for inp in inputs]
    if static_shape_known:
        static_shapes = shapes
    else:
        # Erase static shapes to force the un-optimized Einsum path
        static_shapes = [[None] * len(shape) for shape in shapes]

    operands = [
        pt.tensor(name, shape=static_shape)
        for name, static_shape in zip(ascii_lowercase, static_shapes)
    ]
    out = pt.einsum(signature, *operands)
    # NOTE(review): `==` binds tighter than `or`, so this reads
    # (optimized == static_shape_known) or (len(operands) <= 2) — presumably
    # einsums with at most two operands are always optimized; confirm intent.
    assert out.owner.op.optimized == static_shape_known or len(operands) <= 2

    rng = np.random.default_rng(37)
    test_values = [rng.normal(size=shape).astype(floatX) for shape in shapes]
    np_out = np.einsum(signature, *test_values)

    fn = function(operands, out)
    pt_out = fn(*test_values)

    # print(); fn.dprint(print_type=True)
    if config.mode != "FAST_COMPILE":
        assert_no_blockwise_in_graph(fn.maker.fgraph)
    np.testing.assert_allclose(pt_out, np_out, atol=ATOL, rtol=RTOL)
def test_batch_dim():
    """A leading batch dimension should propagate through the static output shape."""
    x = pt.tensor("x", shape=(7, 3, 5))
    y = pt.tensor("y", shape=(5, 2))
    out = pt.einsum("mij,jk->mik", x, y)
    # Batch dim m=7 is carried through; the core matmul gives (3, 2).
    assert out.type.shape == (7, 3, 2)
def test_einsum_conv():
    """A windowed-convolution contraction should match numpy's einsum."""
    # Adapted example from https://medium.com/latinxinai/vectorized-convolution-operation-using-numpy-b122fd52fba3
    rng = np.random.default_rng(125)

    batch_size, channels = 32, 3
    height, width = 8, 8
    kernel_size, num_filters = 2, 15
    conv_signature = "bchwkt,fckt->bfhw"

    # Pre-windowed input: (batch, channel, height, width, kernel, kernel).
    windowed_input = rng.random(
        size=(batch_size, channels, height, width, kernel_size, kernel_size)
    ).astype(floatX)
    weights = rng.random(
        size=(num_filters, channels, kernel_size, kernel_size)
    ).astype(floatX)

    result = einsum(conv_signature, windowed_input, weights).eval()
    assert result.shape == (32, 15, 8, 8)
    np.testing.assert_allclose(
        result,
        np.einsum("bchwkt,fckt->bfhw", windowed_input, weights),
        atol=ATOL,
        rtol=RTOL,
    )
def test_ellipsis():
    """Ellipsis subscripts should absorb batch dimensions wherever they sit."""
    rng = np.random.default_rng(159)
    x = pt.tensor("x", shape=(3, 5, 7, 11))
    y = pt.tensor("y", shape=(3, 5, 11, 13))
    x_val = rng.normal(size=x.type.shape).astype(floatX)
    y_val = rng.normal(size=y.type.shape).astype(floatX)
    expected_out = np.matmul(x_val, y_val)

    # Without an ellipsis the extra batch dims make the signature invalid.
    with pytest.raises(ValueError):
        pt.einsum("mp,pn->mn", x, y)

    out = pt.einsum("...mp,...pn->...mn", x, y)
    np.testing.assert_allclose(
        out.eval({x: x_val, y: y_val}), expected_out, atol=ATOL, rtol=RTOL
    )

    # Put batch axes in the middle
    new_x = pt.moveaxis(x, -2, 0)
    new_y = pt.moveaxis(y, -2, 0)
    out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
    np.testing.assert_allclose(
        out.eval({x: x_val, y: y_val}),
        expected_out.transpose(-2, 0, 1, -1),
        atol=ATOL,
        rtol=RTOL,
    )

    # Summing over the ellipsis dims in the output is also allowed.
    out = pt.einsum("m...p,p...n->mn", new_x, new_y)
    np.testing.assert_allclose(
        out.eval({x: x_val, y: y_val}),
        expected_out.sum((0, 1)),
        atol=ATOL,
        rtol=RTOL,
    )
def test_broadcastable_dims():
    """Einsum with broadcastable and full dims sharing a subscript.

    Two points are covered:
    1. Numpy einsum allows the same subscript for degenerate and full dimensions.
       There is some stale discussion on whether this should be a bug or not,
       but for now it is not: https://github.com/numpy/numpy/issues/11548
    2. Reusing a letter for dimensions that are and aren't broadcastable can
       lead to suboptimal contraction paths; we check a warning is issued for
       this example: https://github.com/dgasmith/opt_einsum/issues/220
    """
    rng = np.random.default_rng(222)
    a = pt.tensor("a", shape=(32, 32, 32))
    b = pt.tensor("b", shape=(1000, 32))
    c = pt.tensor("c", shape=(1, 32))

    a_val = rng.normal(size=a.type.shape).astype(floatX)
    b_val = rng.normal(size=b.type.shape).astype(floatX)
    c_val = rng.normal(size=c.type.shape).astype(floatX)

    # Note b is used for both 1 and 32
    with pytest.warns(
        UserWarning, match="This can result in a suboptimal contraction path"
    ):
        suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
    assert [set(p) for p in suboptimal_out.owner.op.path] != [{0, 2}, {0, 1}]

    # A distinct letter for the broadcastable dim recovers the optimal path.
    optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
    assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]

    np_eval = np.einsum("ijk,bj,bk->i", a_val, b_val, c_val)
    atol = 1e-12 if config.floatX == "float64" else 1e-2
    np.testing.assert_allclose(
        suboptimal_out.eval({a: a_val, b: b_val, c: c_val}), np_eval, atol=atol
    )
    np.testing.assert_allclose(
        optimal_out.eval({a: a_val, b: b_val, c: c_val}), np_eval, atol=atol
    )
......@@ -14,7 +14,7 @@ from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, as_tensor, constant
from pytensor.tensor.basic import MakeVector, constant, stack
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import (
......@@ -801,8 +801,14 @@ class TestVectorize:
[vect_out] = vectorize_node(node, mat, new_shape).outputs
assert equal_computations([vect_out], [reshape(mat, new_shape)])
with pytest.raises(NotImplementedError):
vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3)))
new_shape = stack([[-1, x], [x - 1, -1]], axis=0)
print(new_shape.type)
[vect_out] = vectorize_node(node, vec, new_shape).outputs
vec_test_value = np.arange(6)
np.testing.assert_allclose(
vect_out.eval({x: 3, vec: vec_test_value}),
np.broadcast_to(vec_test_value.reshape(2, 3), (2, 2, 3)),
)
with pytest.raises(
ValueError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论