Commit 0c138495 authored by Ricardo Vieira, committed by Ricardo Vieira

Make Dot only accept matrix inputs

Parent d1be796e
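Editor's note: a minimal sketch (not part of the diff) of what this change means for callers, using only names that appear in the hunks below (`pt.dot` and the private core Op instance `_dot` from `pytensor.tensor.math`). The helper keeps accepting vectors; only the core Op becomes matrix-only.

import pytensor.tensor as pt
from pytensor.tensor.math import _dot  # core Dot Op instance touched by this commit

A = pt.matrix("A")
v = pt.vector("v")

out = pt.dot(A, v)  # the helper still accepts vectors: dense_dot promotes/squeezes them
try:
    _dot(A, v)      # the core Op itself now only accepts 2D inputs
except TypeError as exc:
    print(exc)      # "Dot Op expects a 2D tensor as input 1, ..."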
@@ -107,7 +107,6 @@ from pytensor.gradient import grad, hessian, jacobian
 from pytensor.tensor import (
     blas,
     blas_c,
-    blas_scipy,
     sharedvar,
     xlogx,
 )
@@ -1801,8 +1801,7 @@ class Alloc(COp):
             | pytensor.tensor.blas.Gemv
             | pytensor.tensor.blas_c.CGemv
             | pytensor.tensor.blas.Ger
-            | pytensor.tensor.blas_c.CGer
-            | pytensor.tensor.blas_scipy.ScipyGer,
+            | pytensor.tensor.blas_c.CGer,
         )
     ):
         # Ops that will work inplace on the Alloc. So if they
@@ -83,6 +83,7 @@ import warnings
 from pathlib import Path
 import numpy as np
+from scipy.linalg import get_blas_funcs
 from pytensor.graph import vectorize_graph
 from pytensor.npy_2_compat import normalize_axis_tuple
@@ -288,18 +289,17 @@ class Ger(Op):
         return Apply(self, inputs, [A.type()])

-    def perform(self, node, inp, out):
-        cA, calpha, cx, cy = inp
-        (cZ,) = out
-        if self.destructive:
-            A = cA
-        else:
-            A = cA.copy()
-        if calpha != 1:
-            A += calpha * np.outer(cx, cy)
-        else:
-            A += np.outer(cx, cy)
-        cZ[0] = A
+    def perform(self, node, inputs, output_storage):
+        A, alpha, x, y = inputs
+        if A.size:
+            # GER doesn't handle zero-sized inputs
+            ger_func = get_blas_funcs("ger", dtype=A.dtype)
+            if A.flags["C_CONTIGUOUS"]:
+                # Work on transposed system to avoid copying
+                A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
+            else:
+                A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
+        output_storage[0][0] = A

     def infer_shape(self, fgraph, node, input_shapes):
         return [input_shapes[0]]
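Editor's note: a small standalone check (not part of the diff) of the BLAS GER call pattern the new `perform` uses. It assumes only `scipy.linalg.get_blas_funcs`; the transposed call mirrors the C-contiguous branch above, since GER computes `alpha * outer(x, y) + A` on a Fortran-ordered array.

import numpy as np
from scipy.linalg import get_blas_funcs

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
x = rng.normal(size=3)
y = rng.normal(size=4)
alpha = 0.5

ger = get_blas_funcs("ger", dtype=A.dtype)
# Work on the transposed system (as the Op does for C-contiguous A):
# ger(alpha, y, x, a=A.T) computes A.T + alpha * outer(y, x); transposing back
# gives A + alpha * outer(x, y).
result = ger(alpha, y, x, a=A.T, overwrite_a=False).T
np.testing.assert_allclose(result, A + alpha * np.outer(x, y))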
@@ -1128,16 +1128,8 @@ class Dot22(GemmRelated):
         outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
         return Apply(self, [x, y], outputs)

-    def perform(self, node, inp, out):
-        x, y = inp
-        (z,) = out
-        try:
-            z[0] = np.asarray(np.dot(x, y))
-        except ValueError as e:
-            # The error raised by numpy has no shape information, we mean to
-            # add that
-            e.args = (*e.args, x.shape, y.shape)
-            raise
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.dot(*inputs)

     def infer_shape(self, fgraph, node, input_shapes):
         return [[input_shapes[0][0], input_shapes[1][1]]]
"""
Implementations of BLAS Ops based on scipy's BLAS bindings.
"""
from pytensor.tensor.blas import Ger
class ScipyGer(Ger):
def perform(self, node, inputs, output_storage):
from scipy.linalg.blas import get_blas_funcs
cA, calpha, cx, cy = inputs
(cZ,) = output_storage
# N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to.
A = cA
local_ger = get_blas_funcs("ger", dtype=cA.dtype)
if A.size == 0:
# We don't have to compute anything, A is empty.
# We need this special case because Numpy considers it
# C-contiguous, which is confusing.
if not self.destructive:
# Sometimes numpy thinks empty matrices can share memory,
# so here to stop DebugMode from complaining.
A = A.copy()
elif A.flags["C_CONTIGUOUS"]:
A = local_ger(calpha, cy, cx, a=A.T, overwrite_a=int(self.destructive)).T
else:
A = local_ger(calpha, cx, cy, a=A, overwrite_a=int(self.destructive))
cZ[0] = A
scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace = ScipyGer(True)
@@ -40,12 +40,13 @@ from pytensor.tensor.elemwise import (
     get_normalized_batch_axes,
     scalar_elemwise,
 )
-from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.shape import shape, specify_shape
 from pytensor.tensor.type import (
     DenseTensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
+    float_dtypes,
     int_dtypes,
     tensor,
     uint_dtypes,
@@ -2986,9 +2987,7 @@ pprint.assign(pow, printing.OperatorPrinter("**", 1, "right"))
 class Dot(Op):
     """
-    Computes the dot product of two variables. For two matrices, this is
-    equivalent to matrix multiplication. For two vectors, this is the inner
-    product.
+    Computes the dot product of two matrix variables.

     Notes
     -----
@@ -3001,97 +3000,58 @@ class Dot(Op):
     """

+    gufunc_signature = "(m,n),(n,p)->(m,p)"
+    gufunc_spec = ("matmul", 2, 1)
     __props__ = ()

-    # the rationale for Dot22 is related to getting GEMM Ops into the
-    # graph. See Dot22 in tensor.blas for details.
-
-    def make_node(self, *inputs):
-        inputs = list(map(as_tensor_variable, inputs))
-
-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, {len(inputs)} given ")
-        if inputs[0].ndim not in (1, 2):
-            raise TypeError(
-                "Input 0 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
-            )
-        if inputs[1].ndim not in (1, 2):
-            raise TypeError(
-                "Input 1 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
-            )
-
-        sx, sy = (input.type.shape for input in inputs)
+    def make_node(self, x, y):
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)
+
+        if x.type.ndim != 2:
+            raise TypeError(
+                f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
+            )
+        if y.type.ndim != 2:
+            raise TypeError(
+                f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
+            )
+
+        sx, sy = x.type.shape, y.type.shape
         if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
             raise ValueError(
                 f"Incompatible shared dimension for dot product: {sx}, {sy}"
             )
-
-        if len(sy) == 2:
-            sz = sx[:-1] + sy[-1:]
-        elif len(sy) == 1:
-            sz = sx[:-1]
-
-        i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
-        return Apply(self, inputs, outputs)
-
-    def perform(self, node, inp, out):
-        x, y = inp
-        (z,) = out
-
-        # the asarray is here because dot between two vectors
-        # gives a numpy float object but we need to return a 0d
-        # ndarray
-        z[0] = np.asarray(np.dot(x, y))
+        out_shape = (sx[0], sy[1])
+        out_dtype = ps.upcast(x.type.dtype, y.type.dtype)
+        outputs = [tensor(dtype=out_dtype, shape=out_shape)]
+        return Apply(self, [x, y], outputs)
+
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.matmul(*inputs)

     def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
-
-        # grad is scalar, so x is vector and y is vector
-        if gdim == 0:
-            xgrad = gz * y
-            ygrad = gz * x
-
-        # x is vector, y is matrix, grad is vector
-        elif xdim == 1 and ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = outer(x.T, gz)
-
-        # x is matrix, y is vector, grad is vector
-        elif xdim == 2 and ydim == 1:
-            xgrad = outer(gz, y.T)
-            ygrad = dot(x.T, gz)
-
-        # x is matrix, y is matrix, grad is matrix
-        elif xdim == ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = dot(x.T, gz)
+
+        xgrad = self(gz, y.T)
+        ygrad = self(x.T, gz)

         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
         # above code don't always return the right broadcast pattern.
         # This cause problem down the road. See gh-1461.
-        if xgrad.broadcastable != x.broadcastable:
-            xgrad = specify_broadcastable(
-                xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
-            )
-        if ygrad.broadcastable != y.broadcastable:
-            ygrad = specify_broadcastable(
-                ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
-            )
-
-        rval = xgrad, ygrad
-
-        for elem in rval:
-            assert elem.dtype.find("float") != -1
-
-        return rval
+        if xgrad.type.shape != x.type.shape:
+            xgrad = specify_shape(xgrad, x.type.shape)
+        if ygrad.type.shape != y.type.shape:
+            ygrad = specify_shape(ygrad, y.type.shape)
+
+        if xgrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad x output must be a float type")
+        if ygrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad y output must be a float type")
+
+        return xgrad, ygrad

     def R_op(self, inputs, eval_points):
         # R_op for a \dot b evaluated at c for a and d for b is
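Editor's note: a quick numeric check (not part of the diff) of the matrix-only gradient kept in the new `grad`: for z = x @ y with upstream gradient gz, the gradients are gz @ y.T and x.T @ gz. This assumes only the public `pytensor.grad`/`pytensor.function` API.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.matrix("y")
z = pt.dot(x, y)
gx, gy = pytensor.grad(z.sum(), [x, y])
f = pytensor.function([x, y], [gx, gy])

rng = np.random.default_rng(0)
xv = rng.normal(size=(2, 3))
yv = rng.normal(size=(3, 4))
gxv, gyv = f(xv, yv)
gz = np.ones((2, 4))  # gradient of z.sum() with respect to z
np.testing.assert_allclose(gxv, gz @ yv.T)
np.testing.assert_allclose(gyv, xv.T @ gz)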
@@ -3116,24 +3076,7 @@ class Dot(Op):
     def infer_shape(self, fgraph, node, shapes):
         xshp, yshp = shapes
-        x, y = node.inputs
-
-        # vector / vector
-        if x.ndim == 1 and y.ndim == 1:
-            return [()]
-        # matrix / vector
-        if x.ndim == 2 and y.ndim == 1:
-            return [xshp[:-1]]
-        # vector / matrix
-        if x.ndim == 1 and y.ndim == 2:
-            return [yshp[-1:]]
-        # matrix / matrix
-        if x.ndim == 2 and y.ndim == 2:
-            return [xshp[:-1] + yshp[-1:]]
-        raise NotImplementedError()
-
-    def __str__(self):
-        return "dot"
+        return [[xshp[0], yshp[1]]]


 _dot = Dot()
@@ -3215,7 +3158,24 @@ def dense_dot(a, b):
     elif a.ndim > 2 or b.ndim > 2:
         return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
     else:
-        return _dot(a, b)
+        row_vector = a.ndim == 1
+        if row_vector:
+            # Promote to row matrix
+            a = a[None]
+
+        col_vector = b.ndim == 1
+        if col_vector:
+            # Promote to column matrix
+            b = b[:, None]
+
+        out = _dot(a, b)
+
+        if row_vector:
+            # If we promoted a to a row matrix, we need to squeeze the first dimension
+            out = out.squeeze(0)
+
+        if col_vector:
+            # If we promoted b to a column matrix, we need to squeeze the last dimension
+            out = out.squeeze(-1)
+
+        return out


 def tensordot(
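Editor's note: a pure-NumPy sketch (not part of the diff) of the promote-then-squeeze handling that `dense_dot` now performs so that vector arguments still work against a matrix-only core Op. The helper name is hypothetical; `@` stands in for the matrix-only Dot.

import numpy as np

def dot_via_matrix_core(a, b):
    # Promote 1D operands to matrices, call the matrix-only core, squeeze back.
    a, b = np.asarray(a), np.asarray(b)
    row_vector = a.ndim == 1
    col_vector = b.ndim == 1
    if row_vector:
        a = a[None]        # (n,) -> (1, n)
    if col_vector:
        b = b[:, None]     # (n,) -> (n, 1)
    out = a @ b            # the matrix-only product
    if row_vector:
        out = out.squeeze(0)
    if col_vector:
        out = out.squeeze(-1)
    return out

assert dot_via_matrix_core(np.ones(3), np.ones(3)) == 3.0  # vector-vector -> scalar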
@@ -3921,11 +3881,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


-_matmul = Blockwise(
-    _dot,
-    signature="(m,k),(k,n)->(m,n)",
-    gufunc_spec=("numpy.matmul", 2, 1),
-)
+_matmul = Blockwise(_dot, name="Matmul")


 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
@@ -3975,7 +3931,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
     if x1.type.ndim == 0 or x2.type.ndim == 0:
         raise ValueError("matmul operand cannot be scalar")
     if x1.type.ndim == 1 and x2.type.ndim == 1:
-        out = _dot(x1, x2)
+        out = vecdot(x1, x2)
     elif x1.type.ndim == 1:
         out = vecmat(x1, x2)
     elif x2.type.ndim == 1:
@@ -4139,23 +4095,7 @@ def vecmat(

 @_vectorize_node.register(Dot)
 def vectorize_node_dot(op, node, batched_x, batched_y):
-    old_x, old_y = node.inputs
-    old_x_ndim = old_x.type.ndim
-    old_y_ndim = old_y.type.ndim
-    match (old_x_ndim, old_y_ndim):
-        case (1, 1):
-            batch_fn = vecdot
-        case (2, 1):
-            batch_fn = matvec
-        case (1, 2):
-            batch_fn = vecmat
-        case (2, 2):
-            batch_fn = matmul
-        case _:
-            raise ValueError(
-                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
-            )
-    return batch_fn(batched_x, batched_y).owner
+    return matmul(batched_x, batched_y).owner


 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
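Editor's note: a sketch (not part of the diff) of the simplified vectorization path. Since the core Dot is always matrix-by-matrix, batching it just re-dispatches through `matmul`. This assumes `vectorize_graph`'s replace-mapping interface from `pytensor.graph` (the import added in blas.py above).

import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

x = pt.matrix("x")
y = pt.matrix("y")
core_out = pt.dot(x, y)

bx = pt.tensor3("bx")  # batch of x matrices
by = pt.tensor3("by")  # batch of y matrices
# Vectorizing the Dot node now goes through matmul, i.e. a Blockwise over the
# leading (batch) dimension of the matrix-only core Op.
batched_out = vectorize_graph(core_out, replace={x: bx, y: by})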
 import pytensor.tensor.rewriting.basic
 import pytensor.tensor.rewriting.blas
 import pytensor.tensor.rewriting.blas_c
-import pytensor.tensor.rewriting.blas_scipy
 import pytensor.tensor.rewriting.blockwise
 import pytensor.tensor.rewriting.einsum
 import pytensor.tensor.rewriting.elemwise
@@ -107,7 +107,6 @@ from pytensor.tensor.math import (
 )
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
-    DenseTensorType,
     TensorType,
     integer_dtypes,
     values_eq_approx_remove_inf_nan,
@@ -580,12 +579,6 @@ class GemmOptimizer(GraphRewriter):
 def local_dot_to_dot22(fgraph, node):
     # This works for tensor.outer too because basic.outer is a macro that
     # produces a dot(dimshuffle,dimshuffle) of form 4 below
-    if not isinstance(node.op, Dot):
-        return
-
-    if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
-        return False
-
     x, y = node.inputs
     if y.type.dtype != x.type.dtype:
         # TODO: upcast one so the types match
@@ -593,16 +586,7 @@ def local_dot_to_dot22(fgraph, node):
         return

     if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
-        if x.ndim == 2 and y.ndim == 2:
-            new_out = [_dot22(*node.inputs)]
-        elif x.ndim == 2 and y.ndim == 1:
-            new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
-        elif x.ndim == 1 and y.ndim == 2:
-            new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
-        elif x.ndim == 1 and y.ndim == 1:
-            new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
-        else:
-            return
+        new_out = [_dot22(*node.inputs)]
         copy_stack_trace(node.outputs, new_out)
         return new_out
pytensor/tensor/rewriting/blas_scipy.py (file deleted)
-from pytensor.graph.rewriting.basic import in2out
-from pytensor.tensor.blas import ger, ger_destructive
-from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace
-from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
-
-
-@node_rewriter([ger, ger_destructive])
-def use_scipy_ger(fgraph, node):
-    if node.op == ger:
-        return [scipy_ger_no_inplace(*node.inputs)]
-
-
-@node_rewriter([scipy_ger_no_inplace])
-def make_ger_destructive(fgraph, node):
-    if node.op == scipy_ger_no_inplace:
-        return [scipy_ger_inplace(*node.inputs)]
-
-
-use_scipy_blas = in2out(use_scipy_ger)
-make_scipy_blas_destructive = in2out(make_ger_destructive)
-
-# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
-# sucks [citation needed], but it is almost always present.
-# C implementations should be scheduled earlier than this, so that they take
-# precedence. Once the original Ger is replaced, then these optimizations
-# have no effect.
-blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
-
-# this matches the InplaceBlasOpt defined in blas.py
-optdb.register(
-    "make_scipy_blas_destructive",
-    make_scipy_blas_destructive,
-    "fast_run",
-    "inplace",
-    position=50.2,
-)
@@ -276,15 +276,7 @@ def cholesky_ldotlt(fgraph, node):
     A = node.inputs[0]
     if not (
-        A.owner is not None
-        and (
-            (
-                isinstance(A.owner.op, Dot)
-                # This rewrite only applies to matrix Dot
-                and A.owner.inputs[0].type.ndim == 2
-            )
-            or (A.owner.op == _matmul)
-        )
+        A.owner is not None and (isinstance(A.owner.op, Dot) or (A.owner.op == _matmul))
     ):
         return
@@ -19,7 +19,6 @@ from pytensor.graph.rewriting.basic import (
     node_rewriter,
 )
 from pytensor.graph.rewriting.utils import get_clients_at_depth
-from pytensor.raise_op import assert_op
 from pytensor.tensor.basic import (
     Alloc,
     Join,
@@ -34,6 +33,7 @@ from pytensor.tensor.basic import (
     ones_like,
     register_infer_shape,
     switch,
+    zeros,
     zeros_like,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -44,12 +44,10 @@ from pytensor.tensor.math import (
     Prod,
     Sum,
     _conj,
-    _dot,
     _matmul,
     add,
     digamma,
     dot,
-    eq,
     erf,
     erfc,
     exp,
@@ -130,16 +128,12 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
     return consts, origconsts, nonconsts


-@register_canonicalize
-@register_stabilize
+@register_canonicalize("shape_unsafe")
+@register_stabilize("shape_unsafe")
 @node_rewriter([Dot])
 def local_0_dot_x(fgraph, node):
-    if not isinstance(node.op, Dot):
-        return False
-
-    x = node.inputs[0]
-    y = node.inputs[1]
-    replace = (
+    x, y = node.inputs
+    if (
         get_underlying_scalar_constant_value(
             x, only_process_constants=True, raise_not_constant=False
         )
@@ -148,26 +142,12 @@ def local_0_dot_x(fgraph, node):
             y, only_process_constants=True, raise_not_constant=False
         )
         == 0
-    )
-
-    if replace:
-        constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
-        if x.ndim == 2 and y.ndim == 2:
-            constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
-            return [alloc(constant_zero, x.shape[0], y.shape[1])]
-        elif x.ndim == 1 and y.ndim == 2:
-            constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
-            return [alloc(constant_zero, y.shape[1])]
-        elif x.ndim == 2 and y.ndim == 1:
-            constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
-            return [alloc(constant_zero, x.shape[0])]
-        elif x.ndim == 1 and y.ndim == 1:
-            constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
-            return [constant_zero]
+    ):
+        return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
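Editor's note: a NumPy illustration (not part of the diff) of why the simplified `local_0_dot_x` is valid, and why it is now tagged "shape_unsafe": a dot with an all-zero operand is just a zeros array of the output shape, but substituting it drops the runtime check that the shared dimensions actually match (previously enforced with `assert_op`/`eq`).

import numpy as np

x = np.zeros((4, 3))
y = np.random.default_rng(0).normal(size=(3, 5))
np.testing.assert_array_equal(
    np.dot(x, y),
    np.zeros((x.shape[0], y.shape[1])),  # what the rewrite substitutes
)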
 @register_canonicalize
-@node_rewriter([DimShuffle])
+@node_rewriter([Dot, _matmul])
 def local_lift_transpose_through_dot(fgraph, node):
     r"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``.
@@ -176,22 +156,25 @@ def local_lift_transpose_through_dot(fgraph, node):
     and to later merge consecutive `DimShuffle`\s.
     """
-    if not (
-        is_matrix_transpose(node.outputs[0])
-        and node.inputs[0].owner
-        and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul))
-    ):
-        return False
-
-    x, y = node.inputs[0].owner.inputs
-
-    if x.ndim >= y.ndim >= 2:
-        # Output is dot product of transposed inputs in reverse order
-        ret = [dot_op(y.mT, x.mT)]
-
-        # Copy over stack trace to output from result of dot-product
-        copy_stack_trace(node.inputs[0], ret)
-        return ret
+    clients = fgraph.clients[node.out]
+
+    if len(clients) != 1:
+        # If the dot is used in more than one place, we don't want to duplicate it
+        return None
+
+    [(client, _)] = clients
+    if not (isinstance(client.op, DimShuffle) and is_matrix_transpose(client.out)):
+        return None
+
+    x, y = node.inputs
+    # Output is dot product of transposed inputs in reverse order
+    ret = node.op(y.mT, x.mT)

+    # Copy over stack trace to output from result of dot-product
+    copy_stack_trace(node.out, ret)
+
+    return {client.out: ret}


 def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool):
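Editor's note: the identity this rewrite relies on, checked in plain NumPy (not part of the diff). The new single-client guard additionally ensures the dot is not duplicated when its result is also used untransposed elsewhere.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))
y = rng.normal(size=(3, 4))
# Transposing a dot product equals the dot of the transposed operands in reverse order.
np.testing.assert_allclose(np.dot(x, y).T, np.dot(y.T, x.T))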
@@ -344,57 +327,34 @@ def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):

 @register_canonicalize
 @register_specialize
-@node_rewriter([_matmul])
-def local_blockwise_dot_to_mul(fgraph, node):
-    """Rewrite blockwise dots that correspond to multiplication without summation.
-
-    We don't touch the regular dot, to not interfere with the BLAS optimizations.
+@node_rewriter([_matmul, Dot])
+def local_dot_to_mul(fgraph, node):
+    """Rewrite dots that correspond to multiplication without summation.
+
+    We don't touch outer product without batch-dimensions, to allow rewriting into GER,
+    which seems more performant in that case.
+
+    # TODO: Once we blockwise Blas operations we shouldn't do it for outer product with batch-dimensions either
+    # TODO: We may still want to canonicalize outer dot as mul, and detect that for GER.
     """
     a, b = node.inputs
     a_static_shape = a.type.shape
     b_static_shape = b.type.shape
-    core_a_ndim = len(node.op.inputs_sig[0])
-    core_b_ndim = len(node.op.inputs_sig[1])
-
-    if core_a_ndim > 2 or core_b_ndim > 2:
-        # Shouldn't happen, but here just in case
-        return None
-
-    if core_b_ndim == 1:
-        if a_static_shape[-1] == 1 or b_static_shape[-1] == 1:
-            if core_a_ndim == 1:
-                # inner product: (..., 1) * (..., 1) -> (...)
-                # just squeeze the last dimensions of a and b
-                new_a = a.squeeze(-1)
-                new_b = b.squeeze(-1)
-            else:
-                # matrix vector product: (..., m, 1) * (..., 1) -> (..., m)
-                # the last dimension of b is already aligned for the elemwise multiplication
-                # after we squeeze the last dimension of a
-                new_a = a.squeeze(-1)
-                new_b = b
-        else:
-            return None
-    else:
-        if a_static_shape[-1] == 1 or b_static_shape[-2] == 1:
-            if core_a_ndim == 1:
-                # vector_matrix product: (..., 1) * (..., 1, n) -> (..., n)
-                # the last dimension of a is already aligned for the elemwise multiplication
-                # after we squeeze the one to last dimension of b
-                new_a = a
-                new_b = b.squeeze(-2)
-            else:
-                # matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
-                # the dimensions of a and b are already aligned for the elemwise multiplication
-                new_a = a
-                new_b = b
-        else:
-            return None
-
-    new_a = copy_stack_trace(a, new_a)
-    new_b = copy_stack_trace(b, new_b)
-
-    new_out = copy_stack_trace(node.out, mul(new_a, new_b))
+
+    # Check if we have matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+    if not (a_static_shape[-1] == 1 or b_static_shape[-2] == 1):
+        return None
+
+    # If it's a core Dot we only rewrite if there's no outer product
+    # (1, 1) * (1, n) or (m, 1) * (1, 1)
+    # Otherwise we leave as is, so GER can be used instead
+    if isinstance(node.op, Dot) and not (
+        a_static_shape[-2] == 1 or b_static_shape[-1] == 1
+    ):
+        return None
+
+    new_out = mul(a, b)
+    copy_stack_trace(node.out, new_out)
     return [new_out]
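Editor's note: the arithmetic fact behind `local_dot_to_mul`, checked in plain NumPy (not part of the diff). When one operand has a length-1 inner dimension, the dot reduces over a single element and is just a broadcasted elementwise multiplication, which is the form the rewrite emits.

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 1))  # (m, 1)
b = rng.normal(size=(1, 5))  # (1, n)
np.testing.assert_allclose(a @ b, a * b)  # outer-product-shaped dot == broadcasted mul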
@@ -158,26 +158,11 @@ def local_subtensor_of_dot(fgraph, node):
     a = a.type.clone(shape=a.type.shape[batch_ndim:])()
     b = b.type.clone(shape=b.type.shape[batch_ndim:])()

-    a_ndim = a.ndim
-    b_ndim = b.ndim
-    num_a_indices = min(a_ndim - 1, len(idx_list))
-    a_indices = idx_list[:num_a_indices]
-    b_indices = idx_list[num_a_indices:]
-
-    # This is necessary because np.dot sums the last index of a with the second to last of b
-    # so we want to skip the second-to-last index into b.
-    # This wasn't necessary for a, because we just omitted the last index.
-    # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
-    # (dot also handles b.ndim < 2 as a special case)
-    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
-        b_indices = (
-            b_indices[: b_ndim - 2]
-            + (slice(None, None, None),)
-            + b_indices[b_ndim - 2 :]
-        )
+    a_indices = idx_list[:1]
+    b_indices = (slice(None), *idx_list[1:])

     a_sub = a[tuple(a_indices)]
-    b_sub = b[tuple(b_indices)] if b_indices else b
+    b_sub = b[tuple(b_indices)]

     r = dot(a_sub, b_sub)
     if batch_ndim:
@@ -37,51 +37,51 @@ def clear_assoccomm():

 def test_kanren_basic():
     A_pt = pt.matrix("A")
-    x_pt = pt.vector("x")
+    B_pt = pt.matrix("B")

-    y_pt = pt.dot(A_pt, x_pt)
+    y_pt = pt.dot(A_pt, B_pt)

     q = var()
-    res = list(run(None, q, eq(y_pt, etuple(_dot, q, x_pt))))
+    res = list(run(None, q, eq(y_pt, etuple(_dot, q, B_pt))))

     assert res == [A_pt]


 def test_KanrenRelationSub_filters():
-    x_pt = pt.vector("x")
-    y_pt = pt.vector("y")
-    z_pt = pt.vector("z")
     A_pt = pt.matrix("A")
+    B_pt = pt.matrix("B")
+    C_pt = pt.matrix("C")
+    D_pt = pt.matrix("D")

     fact(commutative, _dot)
     fact(commutative, pt.add)
     fact(associative, pt.add)

-    Z_pt = A_pt.dot((x_pt + y_pt) + z_pt)
+    Z_pt = A_pt.dot((B_pt + C_pt) + D_pt)

     fgraph = FunctionGraph(outputs=[Z_pt], clone=False)

     def distributes(in_lv, out_lv):
-        A_lv, x_lv, y_lv, z_lv = vars(4)
+        A_lv, B_lv, C_lv, D_lv = vars(4)
         return lall(
             # lhs == A * (x + y + z)
             eq_assoccomm(
-                etuple(_dot, A_lv, etuple(pt.add, x_lv, etuple(pt.add, y_lv, z_lv))),
+                etuple(_dot, A_lv, etuple(pt.add, B_lv, etuple(pt.add, C_lv, D_lv))),
                 in_lv,
             ),
             # This relation does nothing but provide us with a means of
             # generating associative-commutative matches in the `kanren`
             # output.
-            eq((A_lv, x_lv, y_lv, z_lv), out_lv),
+            eq((A_lv, B_lv, C_lv, D_lv), out_lv),
         )

     def results_filter(results):
         _results = [eval_if_etuple(v) for v in results]

         # Make sure that at least a couple permutations are present
-        assert (A_pt, x_pt, y_pt, z_pt) in _results
-        assert (A_pt, y_pt, x_pt, z_pt) in _results
-        assert (A_pt, z_pt, x_pt, y_pt) in _results
+        assert (A_pt, B_pt, C_pt, D_pt) in _results
+        assert (A_pt, C_pt, B_pt, D_pt) in _results
+        assert (A_pt, D_pt, B_pt, C_pt) in _results

         return None
@@ -121,13 +121,13 @@ def test_KanrenRelationSub_multiout():
 def test_KanrenRelationSub_dot():
     """Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached."""
-    x_pt = pt.vector("x")
-    c_pt = pt.vector("c")
-    d_pt = pt.vector("d")
     A_pt = pt.matrix("A")
     B_pt = pt.matrix("B")
+    C_pt = pt.matrix("C")
+    D_pt = pt.matrix("D")
+    E_pt = pt.matrix("E")

-    Z_pt = A_pt.dot(x_pt + B_pt.dot(c_pt + d_pt))
+    Z_pt = A_pt.dot(E_pt + B_pt.dot(C_pt + D_pt))

     fgraph = FunctionGraph(outputs=[Z_pt], clone=False)
@@ -137,15 +137,15 @@ def test_KanrenRelationSub_dot():
         return lall(
             # lhs == A * (x + b)
             eq(
-                etuple(_dot, var("A"), etuple(pt.add, var("x"), var("b"))),
+                etuple(_dot, var("A"), etuple(pt.add, var("E"), var("B"))),
                 in_lv,
             ),
             # rhs == A * x + A * b
             eq(
                 etuple(
                     pt.add,
-                    etuple(_dot, var("A"), var("x")),
-                    etuple(_dot, var("A"), var("b")),
+                    etuple(_dot, var("A"), var("E")),
+                    etuple(_dot, var("A"), var("B")),
                 ),
                 out_lv,
             ),
@@ -631,7 +631,7 @@ def test_Dot(x, y):
     x, x_test_value = x
     y, y_test_value = y

-    g = ptm.Dot()(x, y)
+    g = ptm.dot(x, y)

     compare_numba_and_py(
         [x, y],
@@ -4714,14 +4714,15 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         == 1
     )

-    # For now rewrite only applies to Batched Dots
+    # For now, the only case that is not rewritten is the unbatched outer product
+    core_outer = (not batched) and (a_shape == (3, 1)) and (b_shape == (1, 3))
     rewritten_out = rewrite_graph(out)
     assert rewritten_out.type.shape == out.type.shape
     assert sum(
         isinstance(var.owner.op, (Blockwise | Dot))
         for var in ancestors([rewritten_out])
         if var.owner
-    ) == (0 if batched else 1)
+    ) == (1 if core_outer else 0)

     a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
     b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
@@ -9,7 +9,6 @@ from numpy.testing import assert_array_almost_equal
 import pytensor
 import pytensor.scalar as ps
 import pytensor.tensor as pt
-import pytensor.tensor.blas_scipy
 from pytensor.compile.function import function
 from pytensor.compile.io import In
 from pytensor.compile.mode import Mode
@@ -8,7 +8,6 @@ import pytensor.tensor as pt
 from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Ger
 from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
-from pytensor.tensor.blas_scipy import ScipyGer
 from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
 from tests import unittest_tools
 from tests.tensor.test_blas import BaseGemv, TestBlasStrides
@@ -68,8 +67,6 @@ class TestCGer(OptimizationTestMixin):
         assert CGer(False) == CGer(False)
         assert CGer(False) != CGer(True)

-        assert CGer(True) != ScipyGer(True)
-        assert CGer(False) != ScipyGer(False)
         assert CGer(True) != Ger(True)
         assert CGer(False) != Ger(False)
tests/tensor/test_blas_scipy.py (file deleted)
-import pickle
-
-import numpy as np
-
-import pytensor
-from pytensor import tensor as pt
-from pytensor.tensor.blas_scipy import ScipyGer
-from pytensor.tensor.math import outer
-from pytensor.tensor.type import tensor
-from tests.tensor.test_blas import TestBlasStrides, gemm_no_inplace
-from tests.unittest_tools import OptimizationTestMixin
-
-
-class TestScipyGer(OptimizationTestMixin):
-    def setup_method(self):
-        self.mode = pytensor.compile.get_default_mode()
-        self.mode = self.mode.including("fast_run")
-        self.mode = self.mode.excluding("c_blas")  # c_blas trumps scipy Ops
-        dtype = self.dtype = "float64"  # optimization isn't dtype-dependent
-        self.A = tensor(dtype=dtype, shape=(None, None))
-        self.a = tensor(dtype=dtype, shape=())
-        self.x = tensor(dtype=dtype, shape=(None,))
-        self.y = tensor(dtype=dtype, shape=(None,))
-        self.Aval = np.ones((2, 3), dtype=dtype)
-        self.xval = np.asarray([1, 2], dtype=dtype)
-        self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype)
-
-    def function(self, inputs, outputs):
-        return pytensor.function(inputs, outputs, self.mode)
-
-    def run_f(self, f):
-        f(self.Aval, self.xval, self.yval)
-        f(self.Aval[::-1, ::-1], self.xval[::-1], self.yval[::-1])
-
-    def b(self, bval):
-        return pt.as_tensor_variable(np.asarray(bval, dtype=self.dtype))
-
-    def test_outer(self):
-        f = self.function([self.x, self.y], outer(self.x, self.y))
-        self.assertFunctionContains(f, ScipyGer(destructive=True))
-
-    def test_A_plus_outer(self):
-        f = self.function([self.A, self.x, self.y], self.A + outer(self.x, self.y))
-        self.assertFunctionContains(f, ScipyGer(destructive=False))
-        self.run_f(f)  # DebugMode tests correctness
-
-    def test_A_plus_scaled_outer(self):
-        f = self.function(
-            [self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
-        )
-        self.assertFunctionContains(f, ScipyGer(destructive=False))
-        self.run_f(f)  # DebugMode tests correctness
-
-    def test_scaled_A_plus_scaled_outer(self):
-        f = self.function(
-            [self.A, self.x, self.y], 0.2 * self.A + 0.1 * outer(self.x, self.y)
-        )
-        self.assertFunctionContains(f, gemm_no_inplace)
-        self.run_f(f)  # DebugMode tests correctness
-
-    def test_pickle(self):
-        out = ScipyGer(destructive=False)(self.A, self.a, self.x, self.y)
-        f = pytensor.function([self.A, self.a, self.x, self.y], out)
-        new_f = pickle.loads(pickle.dumps(f))
-        assert isinstance(new_f.maker.fgraph.toposort()[-1].op, ScipyGer)
-        assert np.allclose(
-            f(self.Aval, 1.0, self.xval, self.yval),
-            new_f(self.Aval, 1.0, self.xval, self.yval),
-        )
-
-
-class TestBlasStridesScipy(TestBlasStrides):
-    mode = pytensor.compile.get_default_mode()
-    mode = mode.including("fast_run").excluding("gpu", "c_blas")
@@ -1998,50 +1998,20 @@ class TestMean:
         assert mean(ll).eval() == 1


-def test_dot_numpy_inputs():
-    """Test the `PyTensor.tensor.dot` interface function with NumPy inputs."""
-    a = np.ones(2)
-    b = np.ones(2)
-    res = dot(a, b)
-    assert isinstance(res, Variable)
-    assert isinstance(res.owner.op, Dot)
-
-
 class TestDot:
-    def test_Op_dims(self):
+    def test_valid_ndim(self):
         d0 = scalar()
         d1 = vector()
         d2 = matrix()
         d3 = tensor3()

         with pytest.raises(TypeError):
-            _dot(d0, d0)
-        with pytest.raises(TypeError):
-            _dot(d0, d1)
-        with pytest.raises(TypeError):
             _dot(d0, d2)
         with pytest.raises(TypeError):
-            _dot(d0, d3)
-        with pytest.raises(TypeError):
-            _dot(d1, d0)
-        _dot(d1, d1)
-        _dot(d1, d2)
-        with pytest.raises(TypeError):
-            _dot(d1, d3)
-        with pytest.raises(TypeError):
-            _dot(d2, d0)
-        _dot(d2, d1)
-        _dot(d2, d2)
-        with pytest.raises(TypeError):
-            _dot(d2, d3)
-        with pytest.raises(TypeError):
-            _dot(d3, d0)
-        with pytest.raises(TypeError):
-            _dot(d3, d1)
+            _dot(d1, d2)
         with pytest.raises(TypeError):
             _dot(d3, d2)
-        with pytest.raises(TypeError):
-            _dot(d3, d3)
+        _dot(d2, d2)  # Fine

     def test_grad(self):
         rng = np.random.default_rng(seed=utt.fetch_seed())
@@ -2089,6 +2059,14 @@ class TestDot:
         g = grad(z.sum(), y)
         assert is_super_shape(y, g)

+    def test_dot_numpy_inputs(self):
+        """Test the `PyTensor.tensor.dot` interface function with NumPy inputs."""
+        a = np.ones((2, 2))
+        b = np.ones((2, 2))
+        res = dot(a, b)
+        assert isinstance(res, Variable)
+        assert isinstance(res.owner.op, Dot)
+

 def test_matrix_vector_ops():
     """Test vecdot, matvec, and vecmat helper functions."""
@@ -2796,7 +2774,7 @@ class TestInferShape(utt.InferShapeTester):
         bdvec_val = random(4, rng=rng)
         self._compile_and_check(
             [advec, bdvec],
-            [Dot()(advec, bdvec)],
+            [dot(advec, bdvec)],
             [advec_val, bdvec_val],
             (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
         )
@@ -2808,7 +2786,7 @@ class TestInferShape(utt.InferShapeTester):
         bdmat_val = random(5, 3, rng=rng)
         self._compile_and_check(
             [admat, bdmat],
-            [Dot()(admat, bdmat)],
+            [dot(admat, bdmat)],
             [admat_val, bdmat_val],
             (Dot, blas.Dot22),
         )
@@ -2817,7 +2795,7 @@ class TestInferShape(utt.InferShapeTester):
         bdmat_val = random(4, 5, rng=rng)
         self._compile_and_check(
             [advec, bdmat],
-            [Dot()(advec, bdmat)],
+            [dot(advec, bdmat)],
             [advec_val, bdmat_val],
             (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
         )
@@ -2826,7 +2804,7 @@ class TestInferShape(utt.InferShapeTester):
         admat_val = random(5, 4, rng=rng)
         self._compile_and_check(
             [admat, bdvec],
-            [Dot()(admat, bdvec)],
+            [dot(admat, bdvec)],
             [admat_val, bdvec_val],
             (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
         )
@@ -333,7 +333,7 @@ def test_debugprint():

 def test_debugprint_id_type():
-    a_at = dvector()
+    a_at = dmatrix()
     b_at = dmatrix()

     d_at = b_at.dot(a_at)
@@ -344,10 +344,10 @@ def test_debugprint_id_type():
     s = s.getvalue()

     exp_res = f"""Add [id {e_at.auto_name}]
- ├─ dot [id {d_at.auto_name}]
+ ├─ Dot [id {d_at.auto_name}]
  │  ├─ <Matrix(float64, shape=(?, ?))> [id {b_at.auto_name}]
- │  └─ <Vector(float64, shape=(?,))> [id {a_at.auto_name}]
- └─ <Vector(float64, shape=(?,))> [id {a_at.auto_name}]
+ │  └─ <Matrix(float64, shape=(?, ?))> [id {a_at.auto_name}]
+ └─ <Matrix(float64, shape=(?, ?))> [id {a_at.auto_name}]
 """

     assert [l.strip() for l in s.split("\n")] == [
@@ -312,5 +312,7 @@ def test_dot_errors():
     x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
     y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
     # Doesn't fail until the rewrite
-    with pytest.raises(ValueError, match="not aligned"):
+    with pytest.raises(
+        ValueError, match="Input operand 1 has a mismatch in its core dimension 0"
+    ):
         fn(x_test, y_test)