Commit 0c138495 authored by Ricardo Vieira, committed by Ricardo Vieira

Make Dot only accept matrix inputs

Parent d1be796e
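This commit restricts the core `Dot` Op to matrix-matrix products; vector cases are handled by the user-facing helpers, which promote vectors to matrices and squeeze the result. A minimal sketch of the new contract (not part of the diff; assumes the `pytensor.tensor` API as it appears in the changes below):

import pytensor.tensor as pt
from pytensor.tensor.math import _dot

A = pt.matrix("A")    # shape (m, n)
x = pt.vector("x")    # shape (n,)

out = pt.dot(A, x)    # helper promotes x to (n, 1), then squeezes -> 1D output
assert out.type.ndim == 1

# _dot(A, x) would now raise:
# TypeError: Dot Op expects a 2D tensor as input 1, ...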
......@@ -107,7 +107,6 @@ from pytensor.gradient import grad, hessian, jacobian
from pytensor.tensor import (
blas,
blas_c,
blas_scipy,
sharedvar,
xlogx,
)
......
......@@ -1801,8 +1801,7 @@ class Alloc(COp):
| pytensor.tensor.blas.Gemv
| pytensor.tensor.blas_c.CGemv
| pytensor.tensor.blas.Ger
| pytensor.tensor.blas_c.CGer
| pytensor.tensor.blas_scipy.ScipyGer,
| pytensor.tensor.blas_c.CGer,
)
):
# Ops that will work inplace on the Alloc. So if they
......
......@@ -83,6 +83,7 @@ import warnings
from pathlib import Path
import numpy as np
from scipy.linalg import get_blas_funcs
from pytensor.graph import vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple
......@@ -288,18 +289,17 @@ class Ger(Op):
return Apply(self, inputs, [A.type()])
def perform(self, node, inp, out):
cA, calpha, cx, cy = inp
(cZ,) = out
if self.destructive:
A = cA
else:
A = cA.copy()
if calpha != 1:
A += calpha * np.outer(cx, cy)
else:
A += np.outer(cx, cy)
cZ[0] = A
def perform(self, node, inputs, output_storage):
A, alpha, x, y = inputs
if A.size:
# GER doesn't handle zero-sized inputs
ger_func = get_blas_funcs("ger", dtype=A.dtype)
if A.flags["C_CONTIGUOUS"]:
# Work on transposed system to avoid copying
A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
else:
A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
output_storage[0][0] = A
def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]
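The new `perform` above delegates to SciPy's BLAS GER and, for C-contiguous inputs, works on the transposed system so no copy is needed (BLAS expects Fortran order, and the transpose of a C-contiguous array is F-contiguous). A standalone sketch of that identity, using only the SciPy call already imported in this file:

import numpy as np
from scipy.linalg import get_blas_funcs

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))    # C-contiguous, so A.T is F-contiguous
x, y, alpha = rng.normal(size=3), rng.normal(size=4), 0.5

ger = get_blas_funcs("ger", dtype=A.dtype)
# (A + alpha * outer(x, y)) == (A.T + alpha * outer(y, x)).T
out = ger(alpha, y, x, a=A.T, overwrite_a=False).T
np.testing.assert_allclose(out, A + alpha * np.outer(x, y))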
......@@ -1128,16 +1128,8 @@ class Dot22(GemmRelated):
outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
return Apply(self, [x, y], outputs)
def perform(self, node, inp, out):
x, y = inp
(z,) = out
try:
z[0] = np.asarray(np.dot(x, y))
except ValueError as e:
# The error raised by numpy has no shape information, we mean to
# add that
e.args = (*e.args, x.shape, y.shape)
raise
def perform(self, node, inputs, output_storage):
output_storage[0][0] = np.dot(*inputs)
def infer_shape(self, fgraph, node, input_shapes):
return [[input_shapes[0][0], input_shapes[1][1]]]
......
"""
Implementations of BLAS Ops based on scipy's BLAS bindings.
"""
from pytensor.tensor.blas import Ger
class ScipyGer(Ger):
def perform(self, node, inputs, output_storage):
from scipy.linalg.blas import get_blas_funcs
cA, calpha, cx, cy = inputs
(cZ,) = output_storage
# N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to.
A = cA
local_ger = get_blas_funcs("ger", dtype=cA.dtype)
if A.size == 0:
# We don't have to compute anything, A is empty.
# We need this special case because Numpy considers it
# C-contiguous, which is confusing.
if not self.destructive:
# Sometimes numpy thinks empty matrices can share memory,
# so here to stop DebugMode from complaining.
A = A.copy()
elif A.flags["C_CONTIGUOUS"]:
A = local_ger(calpha, cy, cx, a=A.T, overwrite_a=int(self.destructive)).T
else:
A = local_ger(calpha, cx, cy, a=A, overwrite_a=int(self.destructive))
cZ[0] = A
scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace = ScipyGer(True)
......@@ -40,12 +40,13 @@ from pytensor.tensor.elemwise import (
get_normalized_batch_axes,
scalar_elemwise,
)
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.shape import shape, specify_shape
from pytensor.tensor.type import (
DenseTensorType,
complex_dtypes,
continuous_dtypes,
discrete_dtypes,
float_dtypes,
int_dtypes,
tensor,
uint_dtypes,
......@@ -2986,9 +2987,7 @@ pprint.assign(pow, printing.OperatorPrinter("**", 1, "right"))
class Dot(Op):
"""
Computes the dot product of two variables. For two matrices, this is
equivalent to matrix multiplication. For two vectors, this is the inner
product.
Computes the dot product of two matrix variables.
Notes
-----
......@@ -3001,97 +3000,58 @@ class Dot(Op):
"""
gufunc_signature = "(m,n),(n,p)->(m,p)"
gufunc_spec = ("matmul", 2, 1)
__props__ = ()
# the rationale for Dot22 is related to getting GEMM Ops into the
# graph. See Dot22 in tensor.blas for details.
def make_node(self, *inputs):
inputs = list(map(as_tensor_variable, inputs))
def make_node(self, x, y):
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if len(inputs) != 2:
raise TypeError(f"Two arguments required, {len(inputs)} given ")
if inputs[0].ndim not in (1, 2):
if x.type.ndim != 2:
raise TypeError(
"Input 0 (0-indexed) must have ndim of "
f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
"pytensor.tensor.dot instead."
f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
)
if inputs[1].ndim not in (1, 2):
if y.type.ndim != 2:
raise TypeError(
"Input 1 (0-indexed) must have ndim of "
f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
"pytensor.tensor.dot instead."
f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
)
sx, sy = (input.type.shape for input in inputs)
sx, sy = x.type.shape, y.type.shape
if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
raise ValueError(
f"Incompatible shared dimension for dot product: {sx}, {sy}"
)
out_shape = (sx[0], sy[1])
out_dtype = ps.upcast(x.type.dtype, y.type.dtype)
outputs = [tensor(dtype=out_dtype, shape=out_shape)]
return Apply(self, [x, y], outputs)
if len(sy) == 2:
sz = sx[:-1] + sy[-1:]
elif len(sy) == 1:
sz = sx[:-1]
i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
return Apply(self, inputs, outputs)
def perform(self, node, inp, out):
x, y = inp
(z,) = out
# the asarray is here because dot between two vectors
# gives a numpy float object but we need to return a 0d
# ndarray
z[0] = np.asarray(np.dot(x, y))
def perform(self, node, inputs, output_storage):
output_storage[0][0] = np.matmul(*inputs)
def grad(self, inp, grads):
x, y = inp
(gz,) = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
# grad is scalar, so x is vector and y is vector
if gdim == 0:
xgrad = gz * y
ygrad = gz * x
# x is vector, y is matrix, grad is vector
elif xdim == 1 and ydim == 2:
xgrad = dot(gz, y.T)
ygrad = outer(x.T, gz)
# x is matrix, y is vector, grad is vector
elif xdim == 2 and ydim == 1:
xgrad = outer(gz, y.T)
ygrad = dot(x.T, gz)
# x is matrix, y is matrix, grad is matrix
elif xdim == ydim == 2:
xgrad = dot(gz, y.T)
ygrad = dot(x.T, gz)
xgrad = self(gz, y.T)
ygrad = self(x.T, gz)
# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable:
xgrad = specify_broadcastable(
xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
if ygrad.broadcastable != y.broadcastable:
ygrad = specify_broadcastable(
ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
)
rval = xgrad, ygrad
if xgrad.type.shape != x.type.shape:
xgrad = specify_shape(xgrad, x.type.shape)
if ygrad.type.shape != y.type.shape:
ygrad = specify_shape(ygrad, y.type.shape)
for elem in rval:
assert elem.dtype.find("float") != -1
if xgrad.type.dtype not in float_dtypes:
raise TypeError("Dot grad x output must be a float type")
if ygrad.type.dtype not in float_dtypes:
raise TypeError("Dot grad y output must be a float type")
return rval
return xgrad, ygrad
def R_op(self, inputs, eval_points):
# R_op for a \dot b evaluated at c for a and d for b is
......@@ -3116,24 +3076,7 @@ class Dot(Op):
def infer_shape(self, fgraph, node, shapes):
xshp, yshp = shapes
x, y = node.inputs
# vector / vector
if x.ndim == 1 and y.ndim == 1:
return [()]
# matrix / vector
if x.ndim == 2 and y.ndim == 1:
return [xshp[:-1]]
# vector / matrix
if x.ndim == 1 and y.ndim == 2:
return [yshp[-1:]]
# matrix / matrix
if x.ndim == 2 and y.ndim == 2:
return [xshp[:-1] + yshp[-1:]]
raise NotImplementedError()
def __str__(self):
return "dot"
return [[xshp[0], yshp[1]]]
_dot = Dot()
......@@ -3215,7 +3158,24 @@ def dense_dot(a, b):
elif a.ndim > 2 or b.ndim > 2:
return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
else:
return _dot(a, b)
row_vector = a.ndim == 1
if row_vector:
# Promote to row matrix
a = a[None]
col_vector = b.ndim == 1
if col_vector:
# Promote to column matrix
b = b[:, None]
out = _dot(a, b)
if row_vector:
# If we promoted a to a row matrix, we need to squeeze the first dimension
out = out.squeeze(0)
if col_vector:
# If we promoted b to a column matrix, we need to squeeze the last dimension
out = out.squeeze(-1)
return out
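A plain NumPy sketch of the promotion recipe above (illustrative only): vectors are lifted to row or column matrices, multiplied as matrices, then the added dimension is squeezed away.

import numpy as np

a = np.arange(3.0)                   # (3,) promoted to (1, 3)
B = np.arange(12.0).reshape(3, 4)
np.testing.assert_allclose((a[None] @ B).squeeze(0), a @ B)

b = np.arange(4.0)                   # (4,) promoted to (4, 1)
np.testing.assert_allclose((B @ b[:, None]).squeeze(-1), B @ b)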
def tensordot(
......@@ -3921,11 +3881,7 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))
_matmul = Blockwise(
_dot,
signature="(m,k),(k,n)->(m,n)",
gufunc_spec=("numpy.matmul", 2, 1),
)
_matmul = Blockwise(_dot, name="Matmul")
def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
......@@ -3975,7 +3931,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
if x1.type.ndim == 0 or x2.type.ndim == 0:
raise ValueError("matmul operand cannot be scalar")
if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2)
out = vecdot(x1, x2)
elif x1.type.ndim == 1:
out = vecmat(x1, x2)
elif x2.type.ndim == 1:
......@@ -4139,23 +4095,7 @@ def vecmat(
@_vectorize_node.register(Dot)
def vectorize_node_dot(op, node, batched_x, batched_y):
old_x, old_y = node.inputs
old_x_ndim = old_x.type.ndim
old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim):
case (1, 1):
batch_fn = vecdot
case (2, 1):
batch_fn = matvec
case (1, 2):
batch_fn = vecmat
case (2, 2):
batch_fn = matmul
case _:
raise ValueError(
f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
return batch_fn(batched_x, batched_y).owner
return matmul(batched_x, batched_y).owner
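With the core `Dot` restricted to 2D inputs, vectorizing it reduces to `matmul`, whose batch dimensions follow NumPy broadcasting. A quick sketch of the semantics being relied on:

import numpy as np

x = np.ones((5, 2, 3))    # batch of five (2, 3) matrices
y = np.ones((3, 4))       # one (3, 4) matrix, broadcast across the batch
assert np.matmul(x, y).shape == (5, 2, 4)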
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
......
import pytensor.tensor.rewriting.basic
import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.einsum
import pytensor.tensor.rewriting.elemwise
......
......@@ -107,7 +107,6 @@ from pytensor.tensor.math import (
)
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.type import (
DenseTensorType,
TensorType,
integer_dtypes,
values_eq_approx_remove_inf_nan,
......@@ -580,12 +579,6 @@ class GemmOptimizer(GraphRewriter):
def local_dot_to_dot22(fgraph, node):
# This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below
if not isinstance(node.op, Dot):
return
if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
return False
x, y = node.inputs
if y.type.dtype != x.type.dtype:
# TODO: upcast one so the types match
......@@ -593,16 +586,7 @@ def local_dot_to_dot22(fgraph, node):
return
if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
if x.ndim == 2 and y.ndim == 2:
new_out = [_dot22(*node.inputs)]
elif x.ndim == 2 and y.ndim == 1:
new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
elif x.ndim == 1 and y.ndim == 2:
new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
elif x.ndim == 1 and y.ndim == 1:
new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
else:
return
new_out = [_dot22(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
......
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor.blas import ger, ger_destructive
from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace
from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
@node_rewriter([ger, ger_destructive])
def use_scipy_ger(fgraph, node):
if node.op == ger:
return [scipy_ger_no_inplace(*node.inputs)]
@node_rewriter([scipy_ger_no_inplace])
def make_ger_destructive(fgraph, node):
if node.op == scipy_ger_no_inplace:
return [scipy_ger_inplace(*node.inputs)]
use_scipy_blas = in2out(use_scipy_ger)
make_scipy_blas_destructive = in2out(make_ger_destructive)
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
# sucks [citation needed], but it is almost always present.
# C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations
# have no effect.
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"make_scipy_blas_destructive",
make_scipy_blas_destructive,
"fast_run",
"inplace",
position=50.2,
)
......@@ -276,15 +276,7 @@ def cholesky_ldotlt(fgraph, node):
A = node.inputs[0]
if not (
A.owner is not None
and (
(
isinstance(A.owner.op, Dot)
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
or (A.owner.op == _matmul)
)
A.owner is not None and (isinstance(A.owner.op, Dot) or (A.owner.op == _matmul))
):
return
......
......@@ -19,7 +19,6 @@ from pytensor.graph.rewriting.basic import (
node_rewriter,
)
from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.raise_op import assert_op
from pytensor.tensor.basic import (
Alloc,
Join,
......@@ -34,6 +33,7 @@ from pytensor.tensor.basic import (
ones_like,
register_infer_shape,
switch,
zeros,
zeros_like,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
......@@ -44,12 +44,10 @@ from pytensor.tensor.math import (
Prod,
Sum,
_conj,
_dot,
_matmul,
add,
digamma,
dot,
eq,
erf,
erfc,
exp,
......@@ -130,16 +128,12 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
return consts, origconsts, nonconsts
@register_canonicalize
@register_stabilize
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([Dot])
def local_0_dot_x(fgraph, node):
if not isinstance(node.op, Dot):
return False
x = node.inputs[0]
y = node.inputs[1]
replace = (
x, y = node.inputs
if (
get_underlying_scalar_constant_value(
x, only_process_constants=True, raise_not_constant=False
)
......@@ -148,26 +142,12 @@ def local_0_dot_x(fgraph, node):
y, only_process_constants=True, raise_not_constant=False
)
== 0
)
if replace:
constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
if x.ndim == 2 and y.ndim == 2:
constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
return [alloc(constant_zero, x.shape[0], y.shape[1])]
elif x.ndim == 1 and y.ndim == 2:
constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
return [alloc(constant_zero, y.shape[1])]
elif x.ndim == 2 and y.ndim == 1:
constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
return [alloc(constant_zero, x.shape[0])]
elif x.ndim == 1 and y.ndim == 1:
constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
return [constant_zero]
):
return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
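Since both operands are now guaranteed to be matrices, the rewrite can always emit a single `zeros` graph of shape `(x.shape[0], y.shape[1])`. The claim it encodes, checked with NumPy (a sketch, not part of the diff):

import numpy as np

x = np.zeros((2, 3))
y = np.arange(12.0).reshape(3, 4)
np.testing.assert_array_equal(x @ y, np.zeros((2, 4)))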
@register_canonicalize
@node_rewriter([DimShuffle])
@node_rewriter([Dot, _matmul])
def local_lift_transpose_through_dot(fgraph, node):
r"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``.
......@@ -176,22 +156,25 @@ def local_lift_transpose_through_dot(fgraph, node):
and to later merge consecutive `DimShuffle`\s.
"""
if not (
is_matrix_transpose(node.outputs[0])
and node.inputs[0].owner
and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul))
):
return False
clients = fgraph.clients[node.out]
x, y = node.inputs[0].owner.inputs
if len(clients) != 1:
# If the dot is used in more than one place, we don't want to duplicate it
return None
if x.ndim >= y.ndim >= 2:
# Output is dot product of transposed inputs in reverse order
ret = [dot_op(y.mT, x.mT)]
[(client, _)] = clients
# Copy over stack trace to output from result of dot-product
copy_stack_trace(node.inputs[0], ret)
return ret
if not (isinstance(client.op, DimShuffle) and is_matrix_transpose(client.out)):
return None
x, y = node.inputs
# Output is dot product of transposed inputs in reverse order
ret = node.op(y.mT, x.mT)
# Copy over stack trace to output from result of dot-product
copy_stack_trace(node.out, ret)
return {client.out: ret}
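The rewrite hinges on the transpose identity for matrix products, checked numerically below (sketch only):

import numpy as np

rng = np.random.default_rng(1)
x, y = rng.normal(size=(2, 3)), rng.normal(size=(3, 4))
np.testing.assert_allclose((x @ y).T, y.T @ x.T)    # dot(x, y).T == dot(y.T, x.T)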
def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool):
......@@ -344,57 +327,34 @@ def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
@register_canonicalize
@register_specialize
@node_rewriter([_matmul])
def local_blockwise_dot_to_mul(fgraph, node):
"""Rewrite blockwise dots that correspond to multiplication without summation.
@node_rewriter([_matmul, Dot])
def local_dot_to_mul(fgraph, node):
"""Rewrite dots that correspond to multiplication without summation.
We don't touch the regular dot, to not interfere with the BLAS optimizations.
We don't touch outer product without batch-dimensions, to allow rewriting into GER,
which seems more performant in that case.
# TODO: Once we blockwise Blas operations we shouldn't do it for outer product with batch-dimensions either
# TODO: We may still want to canonicalize outer dot as mul, and detect that for GER.
"""
a, b = node.inputs
a_static_shape = a.type.shape
b_static_shape = b.type.shape
core_a_ndim = len(node.op.inputs_sig[0])
core_b_ndim = len(node.op.inputs_sig[1])
if core_a_ndim > 2 or core_b_ndim > 2:
# Shouldn't happen, but here just in case
# Check if we have matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
if not (a_static_shape[-1] == 1 or b_static_shape[-2] == 1):
return None
if core_b_ndim == 1:
if a_static_shape[-1] == 1 or b_static_shape[-1] == 1:
if core_a_ndim == 1:
# inner product: (..., 1) * (..., 1) -> (...)
# just squeeze the last dimensions of a and b
new_a = a.squeeze(-1)
new_b = b.squeeze(-1)
else:
# matrix vector product: (..., m, 1) * (..., 1) -> (..., m)
# the last dimension of b is already aligned for the elemwise multiplication
# after we squeeze the last dimension of a
new_a = a.squeeze(-1)
new_b = b
else:
return None
else:
if a_static_shape[-1] == 1 or b_static_shape[-2] == 1:
if core_a_ndim == 1:
# vector_matrix product: (..., 1) * (..., 1, n) -> (..., n)
# the last dimension of a is already aligned for the elemwise multiplication
# after we squeeze the one to last dimension of b
new_a = a
new_b = b.squeeze(-2)
else:
# matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
# the dimensions of a and b are already aligned for the elemwise multiplication
new_a = a
new_b = b
else:
return None
# If it's a core Dot we only rewrite if there's no outer product
# (1, 1) * (1, n) or (m, 1) * (1, 1)
# Otherwise we leave as is, so GER can be used instead
if isinstance(node.op, Dot) and not (
a_static_shape[-2] == 1 or b_static_shape[-1] == 1
):
return None
new_a = copy_stack_trace(a, new_a)
new_b = copy_stack_trace(b, new_b)
new_out = copy_stack_trace(node.out, mul(new_a, new_b))
new_out = mul(a, b)
copy_stack_trace(node.out, new_out)
return [new_out]
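The equivalence this rewrite exploits: when the contracted dimension has length 1, the matrix product is just a broadcasted elementwise multiplication. A NumPy sketch:

import numpy as np

rng = np.random.default_rng(2)
a = rng.normal(size=(4, 1))
b = rng.normal(size=(1, 5))
np.testing.assert_allclose(a @ b, a * b)    # (m, 1) @ (1, n) == (m, 1) * (1, n)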
......
......@@ -158,26 +158,11 @@ def local_subtensor_of_dot(fgraph, node):
a = a.type.clone(shape=a.type.shape[batch_ndim:])()
b = b.type.clone(shape=b.type.shape[batch_ndim:])()
a_ndim = a.ndim
b_ndim = b.ndim
num_a_indices = min(a_ndim - 1, len(idx_list))
a_indices = idx_list[:num_a_indices]
b_indices = idx_list[num_a_indices:]
# This is necessary because np.dot sums the last index of a with the second to last of b
# so we want to skip the second-to-last index into b.
# This wasn't necessary for a, because we just omitted the last index.
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
b_indices = (
b_indices[: b_ndim - 2]
+ (slice(None, None, None),)
+ b_indices[b_ndim - 2 :]
)
a_indices = idx_list[:1]
b_indices = (slice(None), *idx_list[1:])
a_sub = a[tuple(a_indices)]
b_sub = b[tuple(b_indices)] if b_indices else b
b_sub = b[tuple(b_indices)]
r = dot(a_sub, b_sub)
if batch_ndim:
......
......@@ -37,51 +37,51 @@ def clear_assoccomm():
def test_kanren_basic():
A_pt = pt.matrix("A")
x_pt = pt.vector("x")
B_pt = pt.matrix("B")
y_pt = pt.dot(A_pt, x_pt)
y_pt = pt.dot(A_pt, B_pt)
q = var()
res = list(run(None, q, eq(y_pt, etuple(_dot, q, x_pt))))
res = list(run(None, q, eq(y_pt, etuple(_dot, q, B_pt))))
assert res == [A_pt]
def test_KanrenRelationSub_filters():
x_pt = pt.vector("x")
y_pt = pt.vector("y")
z_pt = pt.vector("z")
A_pt = pt.matrix("A")
B_pt = pt.matrix("B")
C_pt = pt.matrix("C")
D_pt = pt.matrix("D")
fact(commutative, _dot)
fact(commutative, pt.add)
fact(associative, pt.add)
Z_pt = A_pt.dot((x_pt + y_pt) + z_pt)
Z_pt = A_pt.dot((B_pt + C_pt) + D_pt)
fgraph = FunctionGraph(outputs=[Z_pt], clone=False)
def distributes(in_lv, out_lv):
A_lv, x_lv, y_lv, z_lv = vars(4)
A_lv, B_lv, C_lv, D_lv = vars(4)
return lall(
# lhs == A * (x + y + z)
eq_assoccomm(
etuple(_dot, A_lv, etuple(pt.add, x_lv, etuple(pt.add, y_lv, z_lv))),
etuple(_dot, A_lv, etuple(pt.add, B_lv, etuple(pt.add, C_lv, D_lv))),
in_lv,
),
# This relation does nothing but provide us with a means of
# generating associative-commutative matches in the `kanren`
# output.
eq((A_lv, x_lv, y_lv, z_lv), out_lv),
eq((A_lv, B_lv, C_lv, D_lv), out_lv),
)
def results_filter(results):
_results = [eval_if_etuple(v) for v in results]
# Make sure that at least a couple permutations are present
assert (A_pt, x_pt, y_pt, z_pt) in _results
assert (A_pt, y_pt, x_pt, z_pt) in _results
assert (A_pt, z_pt, x_pt, y_pt) in _results
assert (A_pt, B_pt, C_pt, D_pt) in _results
assert (A_pt, C_pt, B_pt, D_pt) in _results
assert (A_pt, D_pt, B_pt, C_pt) in _results
return None
......@@ -121,13 +121,13 @@ def test_KanrenRelationSub_multiout():
def test_KanrenRelationSub_dot():
"""Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached."""
x_pt = pt.vector("x")
c_pt = pt.vector("c")
d_pt = pt.vector("d")
A_pt = pt.matrix("A")
B_pt = pt.matrix("B")
C_pt = pt.matrix("C")
D_pt = pt.matrix("D")
E_pt = pt.matrix("E")
Z_pt = A_pt.dot(x_pt + B_pt.dot(c_pt + d_pt))
Z_pt = A_pt.dot(E_pt + B_pt.dot(C_pt + D_pt))
fgraph = FunctionGraph(outputs=[Z_pt], clone=False)
......@@ -137,15 +137,15 @@ def test_KanrenRelationSub_dot():
return lall(
# lhs == A * (x + b)
eq(
etuple(_dot, var("A"), etuple(pt.add, var("x"), var("b"))),
etuple(_dot, var("A"), etuple(pt.add, var("E"), var("B"))),
in_lv,
),
# rhs == A * x + A * b
eq(
etuple(
pt.add,
etuple(_dot, var("A"), var("x")),
etuple(_dot, var("A"), var("b")),
etuple(_dot, var("A"), var("E")),
etuple(_dot, var("A"), var("B")),
),
out_lv,
),
......
......@@ -631,7 +631,7 @@ def test_Dot(x, y):
x, x_test_value = x
y, y_test_value = y
g = ptm.Dot()(x, y)
g = ptm.dot(x, y)
compare_numba_and_py(
[x, y],
......
......@@ -4714,14 +4714,15 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
== 1
)
# For now rewrite only applies to Batched Dots
# For now, the only case we do not rewrite is the unbatched outer product
core_outer = (not batched) and (a_shape == (3, 1)) and (b_shape == (1, 3))
rewritten_out = rewrite_graph(out)
assert rewritten_out.type.shape == out.type.shape
assert sum(
isinstance(var.owner.op, (Blockwise | Dot))
for var in ancestors([rewritten_out])
if var.owner
) == (0 if batched else 1)
) == (1 if core_outer else 0)
a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
......
......@@ -9,7 +9,6 @@ from numpy.testing import assert_array_almost_equal
import pytensor
import pytensor.scalar as ps
import pytensor.tensor as pt
import pytensor.tensor.blas_scipy
from pytensor.compile.function import function
from pytensor.compile.io import In
from pytensor.compile.mode import Mode
......
......@@ -8,7 +8,6 @@ import pytensor.tensor as pt
from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Ger
from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
from pytensor.tensor.blas_scipy import ScipyGer
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
from tests import unittest_tools
from tests.tensor.test_blas import BaseGemv, TestBlasStrides
......@@ -68,8 +67,6 @@ class TestCGer(OptimizationTestMixin):
assert CGer(False) == CGer(False)
assert CGer(False) != CGer(True)
assert CGer(True) != ScipyGer(True)
assert CGer(False) != ScipyGer(False)
assert CGer(True) != Ger(True)
assert CGer(False) != Ger(False)
......
import pickle
import numpy as np
import pytensor
from pytensor import tensor as pt
from pytensor.tensor.blas_scipy import ScipyGer
from pytensor.tensor.math import outer
from pytensor.tensor.type import tensor
from tests.tensor.test_blas import TestBlasStrides, gemm_no_inplace
from tests.unittest_tools import OptimizationTestMixin
class TestScipyGer(OptimizationTestMixin):
def setup_method(self):
self.mode = pytensor.compile.get_default_mode()
self.mode = self.mode.including("fast_run")
self.mode = self.mode.excluding("c_blas") # c_blas trumps scipy Ops
dtype = self.dtype = "float64" # optimization isn't dtype-dependent
self.A = tensor(dtype=dtype, shape=(None, None))
self.a = tensor(dtype=dtype, shape=())
self.x = tensor(dtype=dtype, shape=(None,))
self.y = tensor(dtype=dtype, shape=(None,))
self.Aval = np.ones((2, 3), dtype=dtype)
self.xval = np.asarray([1, 2], dtype=dtype)
self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype)
def function(self, inputs, outputs):
return pytensor.function(inputs, outputs, self.mode)
def run_f(self, f):
f(self.Aval, self.xval, self.yval)
f(self.Aval[::-1, ::-1], self.xval[::-1], self.yval[::-1])
def b(self, bval):
return pt.as_tensor_variable(np.asarray(bval, dtype=self.dtype))
def test_outer(self):
f = self.function([self.x, self.y], outer(self.x, self.y))
self.assertFunctionContains(f, ScipyGer(destructive=True))
def test_A_plus_outer(self):
f = self.function([self.A, self.x, self.y], self.A + outer(self.x, self.y))
self.assertFunctionContains(f, ScipyGer(destructive=False))
self.run_f(f) # DebugMode tests correctness
def test_A_plus_scaled_outer(self):
f = self.function(
[self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
)
self.assertFunctionContains(f, ScipyGer(destructive=False))
self.run_f(f) # DebugMode tests correctness
def test_scaled_A_plus_scaled_outer(self):
f = self.function(
[self.A, self.x, self.y], 0.2 * self.A + 0.1 * outer(self.x, self.y)
)
self.assertFunctionContains(f, gemm_no_inplace)
self.run_f(f) # DebugMode tests correctness
def test_pickle(self):
out = ScipyGer(destructive=False)(self.A, self.a, self.x, self.y)
f = pytensor.function([self.A, self.a, self.x, self.y], out)
new_f = pickle.loads(pickle.dumps(f))
assert isinstance(new_f.maker.fgraph.toposort()[-1].op, ScipyGer)
assert np.allclose(
f(self.Aval, 1.0, self.xval, self.yval),
new_f(self.Aval, 1.0, self.xval, self.yval),
)
class TestBlasStridesScipy(TestBlasStrides):
mode = pytensor.compile.get_default_mode()
mode = mode.including("fast_run").excluding("gpu", "c_blas")
......@@ -1998,50 +1998,20 @@ class TestMean:
assert mean(ll).eval() == 1
def test_dot_numpy_inputs():
"""Test the `PyTensor.tensor.dot` interface function with NumPy inputs."""
a = np.ones(2)
b = np.ones(2)
res = dot(a, b)
assert isinstance(res, Variable)
assert isinstance(res.owner.op, Dot)
class TestDot:
def test_Op_dims(self):
def test_valid_ndim(self):
d0 = scalar()
d1 = vector()
d2 = matrix()
d3 = tensor3()
with pytest.raises(TypeError):
_dot(d0, d0)
with pytest.raises(TypeError):
_dot(d0, d1)
with pytest.raises(TypeError):
_dot(d0, d2)
with pytest.raises(TypeError):
_dot(d0, d3)
with pytest.raises(TypeError):
_dot(d1, d0)
_dot(d1, d1)
_dot(d1, d2)
with pytest.raises(TypeError):
_dot(d1, d3)
with pytest.raises(TypeError):
_dot(d2, d0)
_dot(d2, d1)
_dot(d2, d2)
with pytest.raises(TypeError):
_dot(d2, d3)
with pytest.raises(TypeError):
_dot(d3, d0)
with pytest.raises(TypeError):
_dot(d3, d1)
_dot(d1, d2)
with pytest.raises(TypeError):
_dot(d3, d2)
with pytest.raises(TypeError):
_dot(d3, d3)
_dot(d2, d2) # Fine
def test_grad(self):
rng = np.random.default_rng(seed=utt.fetch_seed())
......@@ -2089,6 +2059,14 @@ class TestDot:
g = grad(z.sum(), y)
assert is_super_shape(y, g)
def test_dot_numpy_inputs(self):
"""Test the `PyTensor.tensor.dot` interface function with NumPy inputs."""
a = np.ones((2, 2))
b = np.ones((2, 2))
res = dot(a, b)
assert isinstance(res, Variable)
assert isinstance(res.owner.op, Dot)
def test_matrix_vector_ops():
"""Test vecdot, matvec, and vecmat helper functions."""
......@@ -2796,7 +2774,7 @@ class TestInferShape(utt.InferShapeTester):
bdvec_val = random(4, rng=rng)
self._compile_and_check(
[advec, bdvec],
[Dot()(advec, bdvec)],
[dot(advec, bdvec)],
[advec_val, bdvec_val],
(Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
)
......@@ -2808,7 +2786,7 @@ class TestInferShape(utt.InferShapeTester):
bdmat_val = random(5, 3, rng=rng)
self._compile_and_check(
[admat, bdmat],
[Dot()(admat, bdmat)],
[dot(admat, bdmat)],
[admat_val, bdmat_val],
(Dot, blas.Dot22),
)
......@@ -2817,7 +2795,7 @@ class TestInferShape(utt.InferShapeTester):
bdmat_val = random(4, 5, rng=rng)
self._compile_and_check(
[advec, bdmat],
[Dot()(advec, bdmat)],
[dot(advec, bdmat)],
[advec_val, bdmat_val],
(Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
)
......@@ -2826,7 +2804,7 @@ class TestInferShape(utt.InferShapeTester):
admat_val = random(5, 4, rng=rng)
self._compile_and_check(
[admat, bdvec],
[Dot()(admat, bdvec)],
[dot(admat, bdvec)],
[admat_val, bdvec_val],
(Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
)
......
......@@ -333,7 +333,7 @@ def test_debugprint():
def test_debugprint_id_type():
a_at = dvector()
a_at = dmatrix()
b_at = dmatrix()
d_at = b_at.dot(a_at)
......@@ -344,10 +344,10 @@ def test_debugprint_id_type():
s = s.getvalue()
exp_res = f"""Add [id {e_at.auto_name}]
├─ dot [id {d_at.auto_name}]
├─ Dot [id {d_at.auto_name}]
│ ├─ <Matrix(float64, shape=(?, ?))> [id {b_at.auto_name}]
│ └─ <Vector(float64, shape=(?,))> [id {a_at.auto_name}]
└─ <Vector(float64, shape=(?,))> [id {a_at.auto_name}]
│ └─ <Matrix(float64, shape=(?, ?))> [id {a_at.auto_name}]
└─ <Matrix(float64, shape=(?, ?))> [id {a_at.auto_name}]
"""
assert [l.strip() for l in s.split("\n")] == [
......
......@@ -312,5 +312,7 @@ def test_dot_errors():
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
# Doesn't fail until the rewrite
with pytest.raises(ValueError, match="not aligned"):
with pytest.raises(
ValueError, match="Input operand 1 has a mismatch in its core dimension 0"
):
fn(x_test, y_test)