提交 f25a624a authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: Ricardo Vieira
上级 23427a0a
...@@ -4,6 +4,7 @@ from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify ...@@ -4,6 +4,7 @@ from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
# Load dispatch specializations # Load dispatch specializations
import pytensor.link.jax.dispatch.blas import pytensor.link.jax.dispatch.blas
import pytensor.link.jax.dispatch.blockwise import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.einsum
import pytensor.link.jax.dispatch.elemwise import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.extra_ops import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.pad import pytensor.link.jax.dispatch.pad
......
import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.einsum import Einsum
@jax_funcify.register(Einsum)
def jax_funcify_Einsum(op, **kwargs):
    """Dispatch einsum to JAX.

    This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level.
    This happens when some of the dimension lengths are unknown. This is never a problem in JAX,
    as it always compiles a function per runtime input shape.
    """
    spec = op.subscripts

    def einsum(*operands):
        # JAX recompiles per concrete input shape, so the "optimal"
        # contraction-order search is always applicable here.
        return jnp.einsum(spec, *operands, optimize="optimal")

    return einsum
...@@ -151,6 +151,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable ...@@ -151,6 +151,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable
# isort: off # isort: off
from pytensor.tensor.einsum import einsum
from pytensor.tensor.functional import vectorize from pytensor.tensor.functional import vectorize
# isort: on # isort: on
......
...@@ -1700,21 +1700,22 @@ class Alloc(COp): ...@@ -1700,21 +1700,22 @@ class Alloc(COp):
return False return False
for client, idx in clients: for client, idx in clients:
if isinstance(client.op, Output): client_op = client.op
if isinstance(client_op, Output):
# If the output is a constant, it will have to be deepcopied # If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold. # each time the function is called. So we do not fold.
return False return False
# Allow alloc to be lifted out of Elemwise before constant folding it # Op's through which Alloc can be lifted
elif isinstance(client.op, Elemwise): elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join):
return None return False
# Same for Blockwise, unless it has no batch_dims # Same for Blockwise, unless it has no batch_dims
elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client): elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client):
return None return False
elif ( elif (
# The following ops work inplace of their input id 0. # The following ops work inplace of their input id 0.
idx == 0 idx == 0
and isinstance( and isinstance(
client.op, client_op,
pytensor.tensor.subtensor.IncSubtensor pytensor.tensor.subtensor.IncSubtensor
| pytensor.tensor.subtensor.AdvancedIncSubtensor1 | pytensor.tensor.subtensor.AdvancedIncSubtensor1
| pytensor.tensor.subtensor.AdvancedIncSubtensor | pytensor.tensor.subtensor.AdvancedIncSubtensor
...@@ -2035,10 +2036,15 @@ def transpose(x, axes=None): ...@@ -2035,10 +2036,15 @@ def transpose(x, axes=None):
_x = as_tensor_variable(x) _x = as_tensor_variable(x)
if axes is None: if axes is None:
axes = list(range((_x.type.ndim - 1), -1, -1)) axes = tuple(range((_x.type.ndim - 1), -1, -1))
if tuple(axes) == tuple(range(len(axes))):
# No-op
return _x
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x) ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)): if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T" ret.name = _x.name + ".T"
return ret return ret
...@@ -3950,6 +3956,10 @@ def moveaxis( ...@@ -3950,6 +3956,10 @@ def moveaxis(
source = normalize_axis_tuple(source, a.ndim, "source") source = normalize_axis_tuple(source, a.ndim, "source")
destination = normalize_axis_tuple(destination, a.ndim, "destination") destination = normalize_axis_tuple(destination, a.ndim, "destination")
if source == destination:
# It's a no-op
return a
if len(source) != len(destination): if len(source) != len(destination):
raise ValueError( raise ValueError(
"`source` and `destination` arguments must have the same number of elements" "`source` and `destination` arguments must have the same number of elements"
...@@ -4260,9 +4270,7 @@ atleast_2d = partial(atleast_Nd, n=2) ...@@ -4260,9 +4270,7 @@ atleast_2d = partial(atleast_Nd, n=2)
atleast_3d = partial(atleast_Nd, n=3) atleast_3d = partial(atleast_Nd, n=3)
def expand_dims( def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
a: np.ndarray | TensorVariable, axis: tuple[int, ...]
) -> TensorVariable:
"""Expand the shape of an array. """Expand the shape of an array.
Insert a new axis that will appear at the `axis` position in the expanded Insert a new axis that will appear at the `axis` position in the expanded
...@@ -4281,7 +4289,7 @@ def expand_dims( ...@@ -4281,7 +4289,7 @@ def expand_dims(
""" """
a = as_tensor(a) a = as_tensor(a)
if not isinstance(axis, tuple | list): if not isinstance(axis, Sequence):
axis = (axis,) axis = (axis,)
out_ndim = len(axis) + a.ndim out_ndim = len(axis) + a.ndim
......
差异被折叠。
from collections.abc import Callable from collections.abc import Callable
from pytensor.graph import vectorize_graph from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorVariable
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.tensor.variable import TensorVariable
def vectorize(func: Callable, signature: str | None = None) -> Callable: def vectorize(func: Callable, signature: str | None = None) -> Callable:
......
...@@ -3,10 +3,9 @@ import pytensor.tensor.rewriting.blas ...@@ -3,10 +3,9 @@ import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.einsum
import pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops import pytensor.tensor.rewriting.extra_ops
# Register JAX specializations
import pytensor.tensor.rewriting.jax import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.linalg import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math import pytensor.tensor.rewriting.math
......
...@@ -52,6 +52,7 @@ from pytensor.tensor.basic import ( ...@@ -52,6 +52,7 @@ from pytensor.tensor.basic import (
TensorFromScalar, TensorFromScalar,
alloc, alloc,
as_tensor_variable, as_tensor_variable,
atleast_Nd,
cast, cast,
extract_constant, extract_constant,
fill, fill,
...@@ -1219,3 +1220,123 @@ def local_merge_alloc(fgraph, node): ...@@ -1219,3 +1220,123 @@ def local_merge_alloc(fgraph, node):
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy") register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_alloc(fgraph, node):
    """Lift a DimShuffle above an Alloc.

    dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2)
    """
    ds_input = node.inputs[0]
    inner = ds_input.owner
    if inner is None or not isinstance(inner.op, Alloc):
        return

    dimshuffle_op = node.op
    value, *shape = inner.inputs

    # Make every allocated dimension explicit on the value first
    expanded_value = atleast_Nd(value, n=len(shape))

    # Apply the same permutation/expansion to the value and the alloc shape
    shuffled_value = expanded_value.dimshuffle(dimshuffle_op.new_order)
    new_shape = [shape[i] for i in dimshuffle_op.shuffle]
    for new_dim in dimshuffle_op.augment:
        new_shape.insert(new_dim, 1)
    return [alloc(shuffled_value, *new_shape)]
@register_specialize("shape_unsafe")
@node_rewriter([Join])
def local_join_of_alloc(fgraph, node):
    """Rewrite a Join of Alloc nodes to an Alloc of the Join nodes.

    Allocated dimensions that are broadcastable in every joined input (and are
    not the join axis) are lifted out of the Join and re-allocated afterwards,
    so the Join operates on the smaller core values.
    """
    axis, *tensors = node.inputs

    if len(tensors) < 2:
        # Let other rewrite handle the useless Join
        return

    if not isinstance(axis, Constant):
        return

    core_tensors = []
    alloc_shapes = []
    for tensor in tensors:
        if tensor.owner is None:
            return

        # tensor = expand_dims_to_alloc(tensor)
        if not isinstance(tensor.owner.op, Alloc):
            return

        value, *shape = tensor.owner.inputs

        # Introduce explicit batch dims
        value = atleast_Nd(value, n=len(shape))

        core_tensors.append(value)
        alloc_shapes.append(shape)

    # Find which allocated dimensions can be lifted
    # Axis can never be lifted
    # Non-axis allocated dimensions can be lifted if they are all broadcastable
    [out] = node.outputs
    axis = axis.data

    # broadcasted_dims[dim][i] is True iff input i was broadcast by Alloc
    # along `dim` (broadcastable in the core value, not in the alloc output)
    broadcasted_dims = list(
        zip(
            *(
                [
                    bef and not aft
                    for bef, aft in zip(
                        core_tensor.type.broadcastable,
                        tensor.type.broadcastable,
                        strict=True,
                    )
                ]
                for core_tensor, tensor in zip(core_tensors, tensors, strict=True)
            )
        )
    )

    lifteable_alloc_dims = {
        dim
        for dim in range(out.type.ndim)
        if dim != axis and all(broadcasted_dims[dim])
    }

    if not lifteable_alloc_dims:
        return

    # Lift the allocated dimensions
    new_tensors = []
    # BUG FIX: iterate `tensors` alongside the core values so each new Alloc
    # gets the stack trace of its *matching* original input. Previously this
    # read the `tensor` variable leaked from the validation loop above, which
    # always pointed at the last joined input.
    for tensor, core_tensor, alloc_shape in zip(
        tensors, core_tensors, alloc_shapes, strict=True
    ):
        pre_join_shape = [
            1 if i in lifteable_alloc_dims else alloc_dim
            for i, alloc_dim in enumerate(alloc_shape)
        ]
        new_tensor = alloc(core_tensor, *pre_join_shape)
        copy_stack_trace(tensor, new_tensor)
        new_tensors.append(new_tensor)

    new_join = node.op(axis, *new_tensors)
    copy_stack_trace(node.outputs[0], new_join)

    # Reintroduce the lifted dims
    post_join_shape = []
    for i, alloc_dims in enumerate(zip(*alloc_shapes)):
        if i == axis:
            # The alloc dim along the axis is the sum of all the pre-join alloc dims
            post_join_shape.append(add(*alloc_dims))
        else:
            # Otherwise the shapes should all match. We prioritize constants if any
            for best_alloc_dim in alloc_dims:
                if isinstance(best_alloc_dim, Constant):
                    break
            post_join_shape.append(best_alloc_dim)

    new_out = alloc(new_join, *post_join_shape)
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
...@@ -10,6 +10,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -10,6 +10,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
...@@ -67,10 +68,16 @@ optdb.register( ...@@ -67,10 +68,16 @@ optdb.register(
def local_eager_useless_unbatched_blockwise(fgraph, node): def local_eager_useless_unbatched_blockwise(fgraph, node):
if isinstance( if isinstance(
node.op.core_op, node.op.core_op,
Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor, Dot
| Alloc
| ARange
| Subtensor
| AdvancedSubtensor
| AdvancedIncSubtensor
| Reshape,
): ):
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectored at runtime, # These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes. # since their inputs may imply non-rectangular shapes.
return local_useless_unbatched_blockwise.fn(fgraph, node) return local_useless_unbatched_blockwise.fn(fgraph, node)
...@@ -97,62 +104,67 @@ def local_blockwise_alloc(fgraph, node): ...@@ -97,62 +104,67 @@ def local_blockwise_alloc(fgraph, node):
BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector) BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
""" """
if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
return None
op: Blockwise = node.op # type: ignore op: Blockwise = node.op # type: ignore
batch_ndim = op.batch_ndim(node) batch_ndim = op.batch_ndim(node)
if not batch_ndim: if not batch_ndim:
return None return None
if not any(var.owner and isinstance(var.owner.op, Alloc) for var in node.inputs):
return None
new_inputs = [] new_inputs = []
batch_shapes = [] batch_shapes = []
can_push_any_alloc = False can_push_any_alloc = False
for inp, inp_sig in zip(node.inputs, op.inputs_sig): for inp, inp_sig in zip(node.inputs, op.inputs_sig):
if inp.owner and isinstance(inp.owner.op, Alloc): if not all(inp.type.broadcastable[:batch_ndim]):
# Push batch dims from Alloc if inp.owner and isinstance(inp.owner.op, Alloc):
value, *shape = inp.owner.inputs # Push batch dims from Alloc
value, *shape = inp.owner.inputs
# Check what to do with the value of the Alloc
squeezed_value = _squeeze_left(value, batch_ndim) # Check what to do with the value of the Alloc
missing_ndim = len(shape) - value.type.ndim squeezed_value = _squeeze_left(value, batch_ndim)
if ( missing_ndim = len(shape) - value.type.ndim
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]) if (
!= inp.type.broadcastable[batch_ndim:] (((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
): != inp.type.broadcastable[batch_ndim:]
# We still need an Alloc for the core dims ):
core_shape = shape[batch_ndim:] # We still need an Alloc for the core dims
# And the batch dims of the squeezed value core_shape = shape[batch_ndim:]
squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape) # And the batch dims of the squeezed value
batch_shape = [ squeezed_value_batch_ndim = squeezed_value.type.ndim - len(
1 if broadcastable else dim core_shape
for broadcastable, dim in zip(
squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
) )
] batch_shape = [
squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape) 1 if broadcastable else dim
if squeezed_value.type.broadcastable == inp.type.broadcastable: for broadcastable, dim in zip(
# We can't change anything about this Alloc input squeezed_value.type.broadcastable[
new_inputs.append(inp) :squeezed_value_batch_ndim
continue ],
tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
# We can push batch dims of this Alloc input )
batch_shapes.append( ]
tuple( squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
1 if broadcastable else dim if squeezed_value.type.broadcastable == inp.type.broadcastable:
for broadcastable, dim in zip( # We can't change anything about this Alloc input
inp.type.broadcastable, shape[:batch_ndim] new_inputs.append(inp)
continue
# We can push batch dims of this Alloc input
batch_shapes.append(
tuple(
1 if broadcastable else dim
for broadcastable, dim in zip(
inp.type.broadcastable, shape[:batch_ndim]
)
) )
) )
) new_inputs.append(squeezed_value)
new_inputs.append(squeezed_value) can_push_any_alloc = True
can_push_any_alloc = True continue
else: # Nothing to do with this input other than removing dummy batch dims
# Nothing to do with this input other than removing dummy batch dims new_inputs.append(_squeeze_left(inp, batch_ndim))
new_inputs.append(_squeeze_left(inp, batch_ndim))
if not can_push_any_alloc: if not can_push_any_alloc:
return None return None
...@@ -167,17 +179,15 @@ def local_blockwise_alloc(fgraph, node): ...@@ -167,17 +179,15 @@ def local_blockwise_alloc(fgraph, node):
missing_ndim = old_out_type.ndim - new_out_type.ndim missing_ndim = old_out_type.ndim - new_out_type.ndim
batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim] batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples
if old_out_type.broadcastable[i]:
continue
for batch_dim in batch_dims: for batch_dim in batch_dims:
if batch_dim == 1: if batch_dim == 1:
continue continue
batch_shape[i] = batch_dim
if isinstance(batch_dim, Constant): if isinstance(batch_dim, Constant):
# Give preference to Constants # Give preference to Constants
batch_shape[i] = batch_dim
break break
elif old_out_type.broadcastable[i]:
# Only use non Constant shapes if absolutely necessary
# Otherwise, we use the shape of the non-alloc output
batch_shape[i] = batch_dim
copy_stack_trace(node.outputs, new_outs) copy_stack_trace(node.outputs, new_outs)
new_outs = [ new_outs = [
...@@ -190,3 +200,28 @@ def local_blockwise_alloc(fgraph, node): ...@@ -190,3 +200,28 @@ def local_blockwise_alloc(fgraph, node):
] ]
copy_stack_trace(node.outputs, new_outs) copy_stack_trace(node.outputs, new_outs)
return new_outs return new_outs
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_reshape(fgraph, node):
    """Rewrite away square Blockwise reshapes.

    Reshape is tricky to vectorize eagerly, because a graph like
    `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
    that must be vectorized before we arrive at the reshape operation.

    For the square Reshape case, we must wait for all the intermediate
    operations to be lifted as Allocs
    """
    if not isinstance(node.op.core_op, Reshape):
        return None

    x, output_shape = node.inputs
    batch_ndim = node.op.batch_ndim(node)
    # Only rewrite when all batch dims of the requested shape are broadcastable,
    # i.e. the same core shape is requested for every batch entry.
    if all(output_shape.type.broadcastable[:batch_ndim]):
        # Keep the batch dims of x as-is and reshape only the core dims.
        batched_shape = x.shape[:batch_ndim]
        core_reshape = _squeeze_left(output_shape, batch_ndim)
        new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
        copy_stack_trace(node.outputs[0], new_out)
        return [new_out]
from typing import cast
from pytensor.graph import Apply, FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.tensor.einsum import Einsum, einsum
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.ofg import inline_ofg_node
from pytensor.tensor.variable import TensorVariable
@register_specialize
@node_rewriter([Einsum])
def optimize_einsum_inner_graph(
    fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
    """Try to optimize an einsum that was not optimizable at definition time.

    This can happen when users replace a graph without rebuilding
    Or when during the course of rewrites more specialized static shapes are found
    """
    op: Einsum = node.op

    # Nothing to do if the contraction order was already specialized
    if op.optimized:
        return None

    inputs = node.inputs
    # Every static shape must be fully known to pick a contraction order
    if any(None in inp.type.shape for inp in inputs):
        return None

    optimized_out = einsum(op.subscripts, *inputs)
    assert optimized_out.owner.op.optimized

    copy_stack_trace(node.outputs[0], optimized_out)
    return [optimized_out]
@register_specialize
@node_rewriter([Einsum])
def inline_optimized_einsum(
    fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
    """Inline einsums that are already optimized.

    This allows the inner graph to be optimized with the rest of the graph,
    now that we got ordering right.
    """
    einsum_op: Einsum = node.op
    if einsum_op.optimized:
        return cast(list[TensorVariable], inline_ofg_node(node))
    return None
from pytensor import clone_replace from typing import cast
from pytensor import Variable, clone_replace
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.graph import node_rewriter from pytensor.graph import Apply, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out from pytensor.graph.rewriting.basic import copy_stack_trace, in2out
from pytensor.tensor.basic import AllocDiag from pytensor.tensor.basic import AllocDiag
from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.basic import register_specialize
def inline_ofg_node(node: Apply) -> list[Variable]:
    """Clone an OpFromGraph node's inner outputs with its inner inputs
    substituted by the node's actual inputs, preserving stack traces."""
    ofg = node.op
    assert isinstance(ofg, OpFromGraph)
    replacements = dict(zip(ofg.inner_inputs, node.inputs))
    outs = clone_replace(ofg.inner_outputs, replacements)
    copy_stack_trace(ofg.inner_outputs, outs)
    return cast(list[Variable], outs)
@node_rewriter([OpFromGraph]) @node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node): def inline_ofg_expansion(fgraph, node):
""" """
...@@ -18,10 +30,7 @@ def inline_ofg_expansion(fgraph, node): ...@@ -18,10 +30,7 @@ def inline_ofg_expansion(fgraph, node):
if not op.is_inline: if not op.is_inline:
return False return False
new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))) return inline_ofg_node(node)
copy_stack_trace(op.inner_outputs, new_out)
return new_out
# We want to run this before the first merge optimizer # We want to run this before the first merge optimizer
...@@ -61,8 +70,4 @@ def late_inline_OpFromGraph(fgraph, node): ...@@ -61,8 +70,4 @@ def late_inline_OpFromGraph(fgraph, node):
------- -------
""" """
op = node.op return inline_ofg_node(node)
new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
copy_stack_trace(op.inner_outputs, new_out)
return new_out
...@@ -749,51 +749,43 @@ pytensor.compile.mode.optdb.register( ...@@ -749,51 +749,43 @@ pytensor.compile.mode.optdb.register(
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
def local_reshape_chain(op): @register_canonicalize("shape_unsafe")
@node_rewriter([op]) @register_specialize("shape_unsafe")
def f(fgraph, node): @node_rewriter([Reshape])
""" def local_reshape_chain(fgraph, node):
Reshape(Reshape(shape1),shape2) -> Reshape(shape2) """
Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)
"""
if not check_chain(node, op, op):
return False
# TODO: this can permit a failing program to run by eliminating
# the lower reshape
rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
# Copy over stacktrace from previous output node, as any error
# in new computational graph would have been caused by last op
# in the old computational graph.
copy_stack_trace(node.outputs, rval)
# It might happen that the desired output of this node has a
# broadcastable pattern that does not match that of 'rval'. This is
# when originally, we were able to figure out that one of the
# dimensions of the reshape is one, but some other transformation
# replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this
# rewrite.
if rval.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
if s1 == 1 or s2 == 1
):
return [rval]
else:
return False
return f
"""
if not check_chain(node, Reshape, Reshape):
return False
register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain") rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
# Copy over stacktrace from previous output node, as any error
# in new computational graph would have been caused by last op
# in the old computational graph.
copy_stack_trace(node.outputs, rval)
# It might happen that the desired output of this node has a
# broadcastable pattern that does not match that of 'rval'. This is
# when originally, we were able to figure out that one of the
# dimensions of the reshape is one, but some other transformation
# replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this
# rewrite.
if rval.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
if s1 == 1 or s2 == 1
):
return [rval]
@register_useless @register_useless("shape_unsafe")
@register_canonicalize @register_canonicalize("shape_unsafe")
@register_stabilize @register_specialize("shape_unsafe")
@node_rewriter([Reshape]) @node_rewriter([Reshape])
def local_useless_reshape(fgraph, node): def local_useless_reshape(fgraph, node):
"""Remove two kinds of useless `Reshape`. """Remove two kinds of useless `Reshape`.
...@@ -802,24 +794,17 @@ def local_useless_reshape(fgraph, node): ...@@ -802,24 +794,17 @@ def local_useless_reshape(fgraph, node):
- Remove `Reshape` when reshaping to the shape of the input. - Remove `Reshape` when reshaping to the shape of the input.
""" """
inp = node.inputs[0] inp, output_shape = node.inputs
output = node.outputs[0] [output] = node.outputs
output_shape = node.inputs[1]
if inp.type.ndim != output.type.ndim: if inp.type.ndim != output.type.ndim:
return False return False
# Simple case: both input and output have a single dimension. # Simple case: both input and output have a single dimension.
# TODO FIXME XXX: This could hide errors if the user provides inconsistent
# shapes.
if ( if (
inp.type.ndim == 1 inp.type.ndim == 1
and output.type.ndim == 1 and output.type.ndim == 1
and all( and inp.type.broadcastable == output.type.broadcastable
s1 == s2
for s1, s2 in zip(inp.type.shape, output.type.shape)
if s1 == 1 or s2 == 1
)
): ):
return [inp] return [inp]
...@@ -832,8 +817,15 @@ def local_useless_reshape(fgraph, node): ...@@ -832,8 +817,15 @@ def local_useless_reshape(fgraph, node):
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
# broadcastable and constant dimensions # broadcastable and constant dimensions
if output_shape.owner and isinstance(output_shape.owner.op, MakeVector): if isinstance(output_shape, Constant) or (
output_shape_is = output_shape.owner.inputs output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
):
if isinstance(output_shape, Constant):
output_shape_is = [
as_tensor_variable(dim, ndim=0) for dim in output_shape.data
]
else:
output_shape_is = output_shape.owner.inputs
shape_feature = getattr(fgraph, "shape_feature", None) shape_feature = getattr(fgraph, "shape_feature", None)
...@@ -865,9 +857,9 @@ def local_useless_reshape(fgraph, node): ...@@ -865,9 +857,9 @@ def local_useless_reshape(fgraph, node):
shape_match[dim] = True shape_match[dim] = True
continue continue
# Match 1 if input.type.shape[dim] == 1 # Match constant if input.type.shape[dim] == constant
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
if inp.type.shape[dim] == 1 and cst_outshp_i == 1: if inp.type.shape[dim] == cst_outshp_i:
shape_match[dim] = True shape_match[dim] = True
continue continue
...@@ -881,17 +873,18 @@ def local_useless_reshape(fgraph, node): ...@@ -881,17 +873,18 @@ def local_useless_reshape(fgraph, node):
if shape_feature: if shape_feature:
inpshp_i = shape_feature.get_shape(inp, dim) inpshp_i = shape_feature.get_shape(inp, dim)
if inpshp_i == outshp_i or ( if inpshp_i == outshp_i or (
extract_constant(inpshp_i, only_process_constants=1) extract_constant(inpshp_i, only_process_constants=True)
== extract_constant(outshp_i, only_process_constants=1) == extract_constant(outshp_i, only_process_constants=True)
): ):
shape_match[dim] = True shape_match[dim] = True
continue continue
if all(shape_match) and nb_m1 <= 1: if nb_m1 <= 1 and all(shape_match):
return [inp]
if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
return [inp] return [inp]
# TODO later: if all the shapes except one match, we may want to
# consider it useless as well, like we do in the 1-dim case.
return False return False
...@@ -910,9 +903,8 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -910,9 +903,8 @@ def local_reshape_to_dimshuffle(fgraph, node):
-> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n))) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
""" """
op = node.op op = node.op
inp = node.inputs[0] inp, output_shape = node.inputs
output = node.outputs[0] [output] = node.outputs
output_shape = node.inputs[1]
dimshuffle_new_order = [] dimshuffle_new_order = []
new_output_shape = [] new_output_shape = []
...@@ -944,7 +936,7 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -944,7 +936,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_specialize
@node_rewriter([Reshape]) @node_rewriter([Reshape])
def local_reshape_lift(fgraph, node): def local_reshape_lift(fgraph, node):
""" """
......
...@@ -842,13 +842,13 @@ class Reshape(COp): ...@@ -842,13 +842,13 @@ class Reshape(COp):
@_vectorize_node.register(Reshape) @_vectorize_node.register(Reshape)
def _vectorize_reshape(op, node, x, shape): def _vectorize_reshape(op, node, x, shape):
from pytensor.tensor.blockwise import vectorize_node_fallback
old_x, old_shape = node.inputs old_x, old_shape = node.inputs
batched_ndims = x.type.ndim - old_x.type.ndim batched_ndims = x.type.ndim - old_x.type.ndim
if as_tensor_variable(shape).type.ndim != 1: if as_tensor_variable(shape).type.ndim != 1:
raise NotImplementedError( return vectorize_node_fallback(op, node, x, shape)
"It is not possible to vectorize the shape argument of Reshape"
)
if len(tuple(old_shape)) == len(tuple(shape)): if len(tuple(old_shape)) == len(tuple(shape)):
new_shape = [*x.shape[:batched_ndims], *shape] new_shape = [*x.shape[:batched_ndims], *shape]
......
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
jax = pytest.importorskip("jax")
def test_jax_einsum():
    """Einsum with fully known static shapes compiles and matches NumPy under JAX mode."""
    subscripts = "ij, jk, kl -> il"
    shapes = ((3, 5), (5, 2), (2, 4))
    x = np.random.rand(3, 5)
    y = np.random.rand(5, 2)
    z = np.random.rand(2, 4)

    symbolic_inputs = [
        pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
    ]
    out = pt.einsum(subscripts, *symbolic_inputs)
    fn = pytensor.function(symbolic_inputs, out, mode="JAX")

    np.testing.assert_allclose(fn(x, y, z), np.einsum(subscripts, x, y, z))
@pytest.mark.xfail(raises=NotImplementedError)
def test_ellipsis_einsum():
    """Ellipsis subscripts are not supported yet; expected to raise."""
    subscripts = "...i,...i->..."
    x = np.random.rand(2, 5)
    y = np.random.rand(2, 5)

    symbolic_x = pt.tensor("x", shape=x.shape)
    symbolic_y = pt.tensor("y", shape=y.shape)
    result = pt.einsum(subscripts, symbolic_x, symbolic_y)
    fn = pytensor.function([symbolic_x, symbolic_y], result, mode="JAX")

    np.testing.assert_allclose(fn(x, y), np.einsum(subscripts, x, y))
from functools import partial from functools import partial
from pytensor import function import numpy as np
from pytensor.graph import FunctionGraph, rewrite_graph
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
from pytensor.graph.basic import equal_computations from pytensor.graph.basic import equal_computations
from pytensor.scalar import log as scalar_log from pytensor.scalar import log as scalar_log
from pytensor.tensor import add, alloc, matrix, tensor, tensor3 from pytensor.tensor import add, alloc, matrix, tensor, tensor3
...@@ -9,6 +11,7 @@ from pytensor.tensor.blockwise import Blockwise ...@@ -9,6 +11,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv from pytensor.tensor.nlinalg import MatrixPinv
from pytensor.tensor.rewriting.blockwise import local_useless_blockwise from pytensor.tensor.rewriting.blockwise import local_useless_blockwise
from pytensor.tensor.shape import Reshape
def test_useless_blockwise_of_elemwise(): def test_useless_blockwise_of_elemwise():
...@@ -45,7 +48,7 @@ def test_blockwise_alloc(): ...@@ -45,7 +48,7 @@ def test_blockwise_alloc():
rewrite = partial( rewrite = partial(
rewrite_graph, rewrite_graph,
include=("ShapeOpt", "specialize"), include=("ShapeOpt", "specialize"),
exclude=("local_useless_unbatched_blockwise",), exclude=("local_useless_unbatched_blockwise", "local_dimshuffle_alloc"),
) )
vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)") vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)")
...@@ -104,7 +107,9 @@ def test_blockwise_alloc(): ...@@ -104,7 +107,9 @@ def test_blockwise_alloc():
y = tensor("y", shape=()) y = tensor("y", shape=())
out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5)) out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5) expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5)
assert equal([rewrite(out)], [expected_out]) assert equal(
[rewrite(out)], [expected_out]
), None # pytensor.dprint([expected_out, rewrite(out)], print_type=True)
x = tensor("x", shape=(5,)) x = tensor("x", shape=(5,))
y = tensor("y", shape=()) y = tensor("y", shape=())
...@@ -118,3 +123,27 @@ def test_blockwise_alloc(): ...@@ -118,3 +123,27 @@ def test_blockwise_alloc():
out = vector_add(x, alloc(y, 5)) out = vector_add(x, alloc(y, 5))
expected_out = out expected_out = out
assert equal([rewrite(out)], [expected_out]) assert equal([rewrite(out)], [expected_out])
def test_blockwise_reshape():
    """Vectorizing Reshape produces Blockwise(Reshape); specialization restores a plain Reshape."""
    base = tensor("x", shape=(None, None, None))
    flat = base.reshape([base.shape[0] * base.shape[1], -1])

    batched = tensor("x", shape=(None, None, None, None))
    vect_flat = vectorize_graph(flat, {base: batched})
    assert not isinstance(vect_flat.owner.op, Reshape)
    assert isinstance(vect_flat.owner.op, Blockwise)
    assert isinstance(vect_flat.owner.op.core_op, Reshape)

    rewritten = rewrite_graph(
        vect_flat, include=("canonicalize", "specialize"), clone=True
    )
    assert isinstance(rewritten.owner.op, Reshape)

    # Both graphs must compute the same values, even with rewrites disabled
    no_rewrites = Mode(linker="py", optimizer=None)
    test_x = np.arange(5 * 4 * 3 * 2).reshape(5, 4, 3, 2).astype(config.floatX)
    np.testing.assert_allclose(
        vect_flat.eval({"x": test_x}, mode=no_rewrites),
        rewritten.eval({"x": test_x}, mode=no_rewrites),
    )
from functools import partial
from pytensor.graph import ancestors, rewrite_graph
from pytensor.tensor import einsum, specify_shape, tensor
from pytensor.tensor.einsum import Einsum
# Apply only the "specialize" rewrite phase, operating on a clone of the graph.
specialize_rewrite = partial(rewrite_graph, include=("specialize",), clone=True)
def test_einsum_optimization():
    """The specialize pass inlines an Einsum only once every operand shape is static."""
    a, b, c = (tensor(name, shape=(None, None)) for name in "abc")

    dyn_out = einsum("ij,ij,jk->ik", a, b, c)
    # With unknown dims the contraction path cannot be precomputed
    assert not dyn_out.owner.op.optimized
    assert isinstance(specialize_rewrite(dyn_out).owner.op, Einsum)

    static_inputs = [
        specify_shape(a, (2, 3)),
        specify_shape(b, (2, 3)),
        specify_shape(c, (3, 5)),
    ]
    static_out = dyn_out.owner.clone_with_new_inputs(static_inputs).default_output()
    assert not static_out.owner.op.optimized

    inlined = specialize_rewrite(static_out)
    # Einsum was inlined because it was optimized
    assert not isinstance(inlined.owner.op, Einsum)
    # Sanity check that no Einsum survives anywhere upstream of the result
    assert not any(
        isinstance(var.owner.op, Einsum) for var in ancestors([inlined]) if var.owner
    )
...@@ -337,6 +337,52 @@ class TestLocalUselessReshape: ...@@ -337,6 +337,52 @@ class TestLocalUselessReshape:
topo = f2.maker.fgraph.toposort() topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, Reshape) for n in topo) assert not any(isinstance(n.op, Reshape) for n in topo)
def test_constant_shape(self):
    """Reshape with a constant target is removed iff it provably matches x's static shape."""
    cases = [
        # (static shape of x, constant reshape target, reshape removed?)
        ((2, 3), [2, 3], True),
        ((2, 3), [-1, 3], True),
        ((None, 3), [-1, 3], True),
        # This could be rewritten as a specify_shape(x, (2, 3)), but is not removed
        ((None, 3), [2, 3], False),
        ((2, 3), [3, 2], False),
    ]
    for static_shape, target, removed in cases:
        x = matrix(shape=static_shape)
        out = reshape(x, pt.as_tensor(np.array(target)))
        new_out = rewrite_graph(out)
        if removed:
            assert new_out is x
        else:
            assert new_out is not x
def test_all_but_one_match(self):
    """Reshape where all target entries but one are the matching symbolic dims."""
    x = matrix(shape=(None, None))

    # [x.shape[0], 3] only constrains the second dim -> becomes specify_shape
    rewritten = rewrite_graph(reshape(x, [x.shape[0], 3]))
    assert equal_computations([rewritten], [specify_shape(x, (None, 3))])

    # Rewrite does not apply if there's also a -1
    out = reshape(x, [-1, 3])
    assert rewrite_graph(out) is out
class TestLocalReshapeToDimshuffle: class TestLocalReshapeToDimshuffle:
def setup_method(self): def setup_method(self):
......
...@@ -3847,8 +3847,10 @@ def test_transpose(): ...@@ -3847,8 +3847,10 @@ def test_transpose():
assert np.all(t2d == np.transpose(x2v, [0, 1])) assert np.all(t2d == np.transpose(x2v, [0, 1]))
assert np.all(t3d == np.transpose(x3v, [0, 2, 1])) assert np.all(t3d == np.transpose(x3v, [0, 2, 1]))
# Check we don't introduce useless transpose
assert ptb.transpose(x1) is x1
# Check that we create a name. # Check that we create a name.
assert ptb.transpose(x1).name == "x1.T"
assert ptb.transpose(x2).name == "x2.T" assert ptb.transpose(x2).name == "x2.T"
assert ptb.transpose(x3).name == "x3.T" assert ptb.transpose(x3).name == "x3.T"
assert ptb.transpose(dmatrix()).name is None assert ptb.transpose(dmatrix()).name is None
......
from functools import partial
from string import ascii_lowercase
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
from pytensor.tensor.shape import Reshape
# Fail for unexpected warnings in this file
pytestmark = pytest.mark.filterwarnings("error")

floatX = pytensor.config.floatX
# Tolerances for the allclose checks below, loosened for single precision
ATOL = RTOL = 1e-8 if floatX == "float64" else 1e-4
def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None:
    """Fail if `fgraph` (or any inner graph) contains a Blockwise node.

    When `core_op` is given, only Blockwise nodes wrapping that core op fail;
    any other Blockwise node is tolerated. When it is None, any Blockwise fails.
    """
    for apply in fgraph.apply_nodes:
        op = apply.op
        if isinstance(op, Blockwise):
            if core_op is None:
                raise AssertionError
            assert not isinstance(op.core_op, core_op)

        if isinstance(op, HasInnerGraph):
            # InnerGraph Ops can be rewritten without modifying the original fgraph
            inner = op._fn.maker.fgraph if hasattr(op, "_fn") else op.fgraph
            assert_no_blockwise_in_graph(inner, core_op=core_op)
def test_iota():
    """_iota fills each position with its index along the requested axis."""
    no_rewrites = Mode(linker="py", optimizer=None)
    # Axis 0: every row is filled with its own row index
    np.testing.assert_allclose(
        _iota((4, 8), 0).eval(mode=no_rewrites),
        [[row] * 8 for row in range(4)],
    )
    # Axis 1: every row enumerates the column indices
    np.testing.assert_allclose(
        _iota((4, 8), 1).eval(mode=no_rewrites),
        [list(range(8))] * 4,
    )
def test_delta():
    """_delta builds a (generalized) identity over the listed axes."""
    no_rewrites = Mode(linker="py", optimizer=None)
    # Square case: plain identity matrix
    np.testing.assert_allclose(
        _delta((2, 2), (0, 1)).eval(mode=no_rewrites),
        np.eye(2),
    )
    # Extra trailing axis is broadcast, not part of the delta
    np.testing.assert_allclose(
        _delta((2, 2, 2), (0, 1)).eval(mode=no_rewrites),
        [[[1, 1], [0, 0]], [[0, 0], [1, 1]]],
    )
def test_general_dot():
    """_general_dot matches a batched np.tensordot and avoids Blockwise(Reshape)."""
    rng = np.random.default_rng(45)
    signature = "(l0,a0,a1,l1),(a1,r0,r1,a0)->(l0,l1,r0,r1)"
    tensordot_axes = [(-3, -2), (-1, -4)]

    # X has two batch dims; Y has one batch dim
    x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
    y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
    out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])

    compiled = pytensor.function([x, y], out)
    # compiled.dprint(print_type=True)
    if config.mode != "FAST_COMPILE":
        assert_no_blockwise_in_graph(compiled.maker.fgraph, Reshape)

    reference = np.vectorize(
        partial(np.tensordot, axes=tensordot_axes), signature=signature
    )
    x_test = rng.normal(size=x.type.shape).astype(floatX)
    y_test = rng.normal(size=y.type.shape).astype(floatX)
    np.testing.assert_allclose(
        compiled(x_test, y_test), reference(x_test, y_test), atol=ATOL, rtol=RTOL
    )
@pytest.mark.parametrize("static_shape_known", [True, False])
@pytest.mark.parametrize(
    "signature",
    [
        "ij",
        "ji",
        "ii->i",
        "ii",
        "ij->",
        "ij->j",
        "ij->i",
        "ij,ij->ij",
        "ij,ji->ij",
        "ij,ji->ji",
        "ij,jk",
        "kj,ji",
        "ij,kj->ik",
        "ik,kj->ikj",
        "ij,kl->ijkl",
        "ij,jk,kl->il",
        "kl,ij,jk->il",
        "oij,imj,mjkn,lnk,plk->op",
    ],
)
def test_einsum_signatures(static_shape_known, signature):
    """pt.einsum matches np.einsum across signatures, with or without static shapes."""
    dim_of = dict(zip("ijklmnop", [2, 3, 5, 7, 11, 13, 17, 19], strict=True))

    operand_subscripts = signature.split("->")[0].split(",")
    shapes = [tuple(dim_of[letter] for letter in sub) for sub in operand_subscripts]
    static_shapes = (
        shapes if static_shape_known else [[None] * len(shape) for shape in shapes]
    )

    operands = [
        pt.tensor(name, shape=st_shape)
        for name, st_shape in zip(ascii_lowercase, static_shapes)
    ]
    out = pt.einsum(signature, *operands)
    # The path is only pre-optimized when every shape is known
    # (contractions of <= 2 operands always count as optimized)
    assert out.owner.op.optimized == static_shape_known or len(operands) <= 2

    rng = np.random.default_rng(37)
    test_values = [rng.normal(size=shape).astype(floatX) for shape in shapes]
    np_out = np.einsum(signature, *test_values)

    fn = function(operands, out)
    # print(); fn.dprint(print_type=True)
    if config.mode != "FAST_COMPILE":
        assert_no_blockwise_in_graph(fn.maker.fgraph)
    np.testing.assert_allclose(fn(*test_values), np_out, atol=ATOL, rtol=RTOL)
def test_batch_dim():
    """A leading batch dim propagates through einsum to the output's static shape."""
    x = pt.tensor("x", shape=(7, 3, 5))
    y = pt.tensor("y", shape=(5, 2))
    out = pt.einsum("mij,jk->mik", x, y)
    assert out.type.shape == (7, 3, 2)
def test_einsum_conv():
    """Convolution expressed as einsum over an im2col-style windowed input."""
    # Adapted example from https://medium.com/latinxinai/vectorized-convolution-operation-using-numpy-b122fd52fba3
    rng = np.random.default_rng(125)
    batch_size, channels = 32, 3
    height, width = 8, 8
    kernel_size, num_filters = 2, 15
    conv_signature = "bchwkt,fckt->bfhw"

    windowed_input = rng.random(
        size=(batch_size, channels, height, width, kernel_size, kernel_size)
    ).astype(floatX)
    weights = rng.random(
        size=(num_filters, channels, kernel_size, kernel_size)
    ).astype(floatX)

    result = einsum(conv_signature, windowed_input, weights).eval()

    assert result.shape == (32, 15, 8, 8)
    np.testing.assert_allclose(
        result,
        np.einsum("bchwkt,fckt->bfhw", windowed_input, weights),
        atol=ATOL,
        rtol=RTOL,
    )
def test_ellipsis():
    """Ellipsis subscripts absorb batch axes, wherever those axes sit."""
    rng = np.random.default_rng(159)
    x = pt.tensor("x", shape=(3, 5, 7, 11))
    y = pt.tensor("y", shape=(3, 5, 11, 13))
    x_test = rng.normal(size=x.type.shape).astype(floatX)
    y_test = rng.normal(size=y.type.shape).astype(floatX)
    expected_out = np.matmul(x_test, y_test)

    # Without "..." the extra batch axes make the signature invalid
    with pytest.raises(ValueError):
        pt.einsum("mp,pn->mn", x, y)

    batched = pt.einsum("...mp,...pn->...mn", x, y)
    np.testing.assert_allclose(
        batched.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
    )

    # Put batch axes in the middle
    mid_x = pt.moveaxis(x, -2, 0)
    mid_y = pt.moveaxis(y, -2, 0)
    mid_out = pt.einsum("m...p,p...n->m...n", mid_x, mid_y)
    np.testing.assert_allclose(
        mid_out.eval({x: x_test, y: y_test}),
        expected_out.transpose(-2, 0, 1, -1),
        atol=ATOL,
        rtol=RTOL,
    )

    # Dropping "..." from the output sums over the batch axes
    summed = pt.einsum("m...p,p...n->mn", mid_x, mid_y)
    np.testing.assert_allclose(
        summed.eval({x: x_test, y: y_test}),
        expected_out.sum((0, 1)),
        atol=ATOL,
        rtol=RTOL,
    )
def test_broadcastable_dims():
    # Test that einsum handles broadcasting dims correctly. There are two points:
    # 1. Numpy einsum allows the same subscript for degenerate and full dimensions
    # There is some stale discussion on whether this should be a bug or not, but for now it is not:
    # https://github.com/numpy/numpy/issues/11548
    # 2. Using the same letter for dimensions that are and aren't broadcastable
    # can lead to suboptimal paths. We check we issue a warning for the following example:
    # https://github.com/dgasmith/opt_einsum/issues/220
    rng = np.random.default_rng(222)
    a = pt.tensor("a", shape=(32, 32, 32))
    b = pt.tensor("b", shape=(1000, 32))
    c = pt.tensor("c", shape=(1, 32))

    a_test, b_test, c_test = (
        rng.normal(size=var.type.shape).astype(floatX) for var in (a, b, c)
    )

    best_path = [{0, 2}, {0, 1}]

    # Note b is used for both 1 and 32
    with pytest.warns(
        UserWarning, match="This can result in a suboptimal contraction path"
    ):
        suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
    assert [set(p) for p in suboptimal_out.owner.op.path] != best_path

    # If we use a distinct letter we get the optimal path
    optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
    assert [set(p) for p in optimal_out.owner.op.path] == best_path

    np_eval = np.einsum("ijk,bj,bk->i", a_test, b_test, c_test)
    atol = 1e-12 if config.floatX == "float64" else 1e-2
    feed = {a: a_test, b: b_test, c: c_test}
    np.testing.assert_allclose(suboptimal_out.eval(feed), np_eval, atol=atol)
    np.testing.assert_allclose(optimal_out.eval(feed), np_eval, atol=atol)
...@@ -14,7 +14,7 @@ from pytensor.graph.type import Type ...@@ -14,7 +14,7 @@ from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, as_tensor, constant from pytensor.tensor.basic import MakeVector, constant, stack
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
...@@ -801,8 +801,14 @@ class TestVectorize: ...@@ -801,8 +801,14 @@ class TestVectorize:
[vect_out] = vectorize_node(node, mat, new_shape).outputs [vect_out] = vectorize_node(node, mat, new_shape).outputs
assert equal_computations([vect_out], [reshape(mat, new_shape)]) assert equal_computations([vect_out], [reshape(mat, new_shape)])
with pytest.raises(NotImplementedError): new_shape = stack([[-1, x], [x - 1, -1]], axis=0)
vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3))) print(new_shape.type)
[vect_out] = vectorize_node(node, vec, new_shape).outputs
vec_test_value = np.arange(6)
np.testing.assert_allclose(
vect_out.eval({x: 3, vec: vec_test_value}),
np.broadcast_to(vec_test_value.reshape(2, 3), (2, 2, 3)),
)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论