Commit 75485c13 authored by Ricardo Vieira, committed by Ricardo Vieira

Simplify BatchedDot implementation

The Op now always expects rank-3 inputs; any required dimshuffles are applied explicitly by the batched_dot helper function.
Parent 18f245fa
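For illustration, a minimal sketch of the new contract (assuming the public helper is importable from pytensor.tensor.blas, where the diff below defines it): rank-2 inputs are no longer handled by the Op itself; the helper expands them to rank 3 and squeezes the result back.

import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.blas import batched_dot

x = at.dtensor3("x")   # rank-3 input: consumed by the Op directly
y = at.dmatrix("y")    # rank-2 input: expanded by the helper, then squeezed
f = pytensor.function([x, y], batched_dot(x, y))

a = np.random.rand(2, 3, 4)
b = np.random.rand(2, 4)
assert f(a, b).shape == (2, 3)   # dot(a[i], b[i]) stacked over the batch axis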
@@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
     def batched_dot(a, b):
         if a.shape[0] != b.shape[0]:
             raise TypeError("Shapes must match in the 0-th dimension")
-        if a.ndim == 2 or b.ndim == 2:
-            return jnp.einsum("n...j,nj...->n...", a, b)
-        return jnp.einsum("nij,njk->nik", a, b)
+        return jnp.matmul(a, b)

     return batched_dot
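A quick equivalence check (illustrative, not part of the commit): for rank-3 operands, jnp.matmul already treats the leading axis as a batch, so it matches the "nij,njk->nik" einsum it replaces.

import numpy as np
import jax.numpy as jnp

a = jnp.asarray(np.random.rand(2, 3, 4))
b = jnp.asarray(np.random.rand(2, 4, 5))
assert jnp.allclose(jnp.matmul(a, b), jnp.einsum("nij,njk->nik", a, b))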
@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
     @numba_njit
     def batched_dot(x, y):
+        # Numba does not support 3D matmul
+        # https://github.com/numba/numba/issues/3804
         shape = x.shape[:-1] + y.shape[2:]
         z0 = np.empty(shape, dtype=dtype)
         for i in range(z0.shape[0]):
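The loop body is elided above; a stand-in sketch (an assumption, not the commit's exact code) of a loop-based batched dot in a form Numba can compile:

import numpy as np
from numba import njit

@njit
def batched_dot_loop(x, y):
    # (batch, m, n) for tensor3 inputs of shapes (batch, m, k) and (batch, k, n)
    shape = x.shape[:-1] + y.shape[2:]
    z0 = np.empty(shape, dtype=x.dtype)
    for i in range(z0.shape[0]):
        z0[i] = np.dot(x[i], y[i])   # 2D dot per batch entry
    return z0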
@@ -98,10 +98,11 @@ from pytensor.link.c.params_type import ParamsType
 from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
 from pytensor.tensor import basic as at
+from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import add, mul, neg, sub
-from pytensor.tensor.shape import specify_broadcastable
+from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
 from pytensor.utils import memoize
@@ -1637,48 +1638,53 @@ _dot22scalar = Dot22Scalar()
 class BatchedDot(COp):
     """
-    Computes the batched dot product of two variables:
+    Computes a batch matrix-matrix dot with tensor3 variables

         batched_dot(a, b)[i] = dot(a[i], b[i])
     """

     __props__ = ()
+    gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"

-    def make_node(self, *inputs):
-        inputs = list(map(at.as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = at.as_tensor_variable(x)
+        y = at.as_tensor_variable(y)

-        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+        if not (
+            isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
+        ):
             raise NotImplementedError("Only dense tensor types are supported")

-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, but {len(inputs)} given.")
-        if inputs[0].ndim not in (2, 3):
-            raise TypeError(
-                "Input 0 (0-indexed)"
-                f" must have ndim of 2 or 3, {int(inputs[0].ndim)} given. Consider"
-                " calling batched_dot instead."
-            )
-        if inputs[1].ndim not in (2, 3):
+        if not (x.type.ndim == 3 and y.type.ndim == 3):
             raise TypeError(
-                "Input 1 (0-indexed)"
-                f" must have ndim of 2 or 3, {int(inputs[1].ndim)} given. Consider"
-                " calling batched_dot instead."
+                f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. "
+                "Consider calling batched_dot instead."
             )

-        dtype = pytensor.scalar.upcast(*[input.type.dtype for input in inputs])
-        # upcast inputs to common dtype if needed
-        upcasted_inputs = [at.cast(input, dtype) for input in inputs]
-        out_shape = (
-            (
-                1
-                if inputs[0].type.shape[0] == 1 or inputs[1].type.shape[0] == 1
-                else None,
-            )
-            + inputs[0].type.shape[1:-1]
-            + inputs[1].type.shape[2:]
-        )
-        out_shape = tuple(1 if s == 1 else None for s in out_shape)
-        return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])
+        def extract_static_dim(dim_x, dim_y):
+            dims = {dim_x, dim_y} - {None}
+            if len(dims) > 1:
+                # BatchedDot doesn't allow broadcasting
+                raise ValueError(
+                    f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}"
+                )
+            elif not dims:
+                return None
+            else:
+                return dims.pop()
+
+        x_batch_dim, x_row_dim, x_sum_dim = x.type.shape
+        y_batch_dim, y_sum_dim, y_col_dim = y.type.shape
+        batch_dim = extract_static_dim(x_batch_dim, y_batch_dim)
+        # Raise if static sum dimensions do not match
+        _ = extract_static_dim(x_sum_dim, y_sum_dim)
+        out_shape = (batch_dim, x_row_dim, y_col_dim)
+
+        # Change dtype if needed
+        dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
+        x, y = at.cast(x, dtype), at.cast(y, dtype)
+        out = tensor(dtype=dtype, shape=out_shape)
+        return Apply(self, [x, y], [out])

     def perform(self, node, inp, out):
         x, y = inp
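An illustrative consequence of the static-shape rules above (assuming the module-level _batched_dot Op instance referenced later in this diff): known and unknown (None) dims are merged, and mismatched static dims raise, since BatchedDot does not broadcast.

import pytensor.tensor as at
from pytensor.tensor.blas import _batched_dot

x = at.tensor(name="x", dtype="float64", shape=(5, None, 4))
y = at.tensor(name="y", dtype="float64", shape=(None, 4, 7))
out = _batched_dot(x, y)
assert out.type.shape == (5, None, 7)   # batch dim taken from x, columns from y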
@@ -1690,11 +1696,7 @@ class BatchedDot(COp):
                 f" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]."
             )

-        shape = self.infer_shape(None, node, [i.shape for i in inp])[0]
-        dtype = node.outputs[0].dtype
-        z0 = z[0] = np.empty(shape, dtype=dtype)
-        for i in range(z0.shape[0]):
-            z0[i] = np.dot(x[i], y[i])
+        z[0] = np.matmul(x, y)

     def c_support_code(self, **kwargs):
         batch_gemm_defn = """
@@ -1792,14 +1794,6 @@ class BatchedDot(COp):
     def c_header_dirs(self, **kwargs):
         return ldflags(libs=False, include_dir=True)

-    def c_code_cleanup(self, node, name, inputs, outputs, sub):
-        return """
-        // clean up views
-        Py_XDECREF(xs); xs = 0;
-        Py_XDECREF(ys); ys = 0;
-        Py_XDECREF(zs); zs = 0;
-        """
-
     def c_code(self, node, name, inp, out, sub):
         _x, _y = inp
         (_z,) = out
@@ -1832,12 +1826,11 @@ class BatchedDot(COp):
         )

         # generate code to allocate output based on runtime input shapes
-        z_dims = [f"PyArray_DIMS({_x})[0]"]
-        if x_ndim == 3:
-            z_dims.append(f"PyArray_DIMS({_x})[1]")
-        if y_ndim == 3:
-            z_dims.append(f"PyArray_DIMS({_y})[2]")
-        assert len(z_dims) == z_ndim
+        z_dims = [
+            f"PyArray_DIMS({_x})[0]",
+            f"PyArray_DIMS({_x})[1]",
+            f"PyArray_DIMS({_y})[2]",
+        ]

         z_shape_correct = " && ".join(
             "PyArray_DIMS(%s)[%i] == %s" % (_z, i, dim) for i, dim in enumerate(z_dims)
@@ -1880,76 +1873,26 @@ class BatchedDot(COp):
         )
         contiguate = "\n".join(contiguate)

-        def c_dimshuffle(newname, oldname, shape):
-            _fail = fail
-            _shape = ", ".join(
-                "1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis)
-                for axis in shape
-            )
-            return (
-                """{
-                npy_intp dims[3] = {%(_shape)s};
-                PyArray_Dims newshape = {dims, 3};
-                %(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
-                if (!%(newname)s)
-                    %(_fail)s
-                // make sure we didn't accidentally copy
-                assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
-            }"""
-                % locals()
-            )
-
-        # create tensor3 views for any of x, y, z that are not tensor3, so that
-        # we only need to implement the tensor3-tensor3 batched dot product.
-        # xs, ys and zs will point to these views, or to the original array if
-        # it was already tensor3.
-        # in the latter case, we artificially increase the reference count of
-        # the original array so that the c_code_cleanup method can decref them
-        # all indiscriminately.
-        upcast = []
-        if x_ndim == 3:
-            upcast.append("xs = %(_x)s; Py_XINCREF(xs);")
-        elif x_ndim == 2:
-            upcast.append(c_dimshuffle("xs", _x, (0, None, 1)))
-        if y_ndim == 3:
-            upcast.append("ys = %(_y)s; Py_XINCREF(ys);")
-        elif y_ndim == 2:
-            upcast.append(c_dimshuffle("ys", _y, (0, 1, None)))
-        if z_ndim == 3:
-            upcast.append("zs = %(_z)s; Py_XINCREF(zs);")
-        else:
-            upcast.append(
-                c_dimshuffle(
-                    "zs",
-                    _z,
-                    (0, None if x_ndim == 2 else 1, None if y_ndim == 2 else 1),
-                )
-            )
-        upcast = "\n".join(upcast) % locals()
-
         return (
             """
         int type_num = PyArray_DESCR(%(_x)s)->type_num;
         int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes

-        // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
-        PyArrayObject *xs = 0, *ys = 0, *zs = 0;
-
-        if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
+        if (PyArray_NDIM(%(_x)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(x) != %(x_ndim)s. rank(x) is %%d.",
+                         "rank(x) != 3. rank(x) is %%d.",
                          PyArray_NDIM(%(_x)s));
             %(fail)s;
         }
-        if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) {
+        if (PyArray_NDIM(%(_y)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(y) != %(y_ndim)s. rank(y) is %%d.",
+                         "rank(y) != 3. rank(y) is %%d.",
                          PyArray_NDIM(%(_y)s));
             %(fail)s;
         }
-        if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) {
+        if (%(_z)s && PyArray_NDIM(%(_z)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(z) != %(z_ndim)s. rank(z) is %%d.",
+                         "rank(z) != 3. rank(z) is %%d.",
                          PyArray_NDIM(%(_z)s));
             %(fail)s;
         }
@@ -1958,36 +1901,32 @@ class BatchedDot(COp):
        %(allocate)s
        // reallocate any noncontiguous arrays or arrays with invalid strides
        %(contiguate)s

-        // add dims to make sure everything is tensor3
-        %(upcast)s
-
-        // from here on, use xs, ys and zs as they are tensor3 and share memory
-        // with the original %(_x)s, %(_y)s and %(_z)s arrays.
-
-        if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(ys)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(zs)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_z)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_z)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num)
-            ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
+        if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num)
+            ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_z)s)->type_num))
         { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }

         switch (type_num)
         {
             case NPY_FLOAT:
-                if (batch_gemm<float>(sgemm_, type_size, xs, ys, zs)) {
+                if (batch_gemm<float>(sgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
                     %(fail)s;
                 }
                 break;
             case NPY_DOUBLE:
-                if (batch_gemm<double>(dgemm_, type_size, xs, ys, zs)) {
+                if (batch_gemm<double>(dgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
                     %(fail)s;
                 }
                 break;
@@ -1999,30 +1938,12 @@ class BatchedDot(COp):
     def c_code_cache_version(self):
         from pytensor.tensor.blas_headers import blas_header_version

-        return (4, blas_header_version())
+        return (5, blas_header_version())

     def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
-
-        # grad is a vector, so x is a matrix and y is a matrix
-        if gdim == 1:
-            xgrad = gz.dimshuffle(0, "x") * y
-            ygrad = gz.dimshuffle(0, "x") * x
-
-        # x is a matrix, y is a tensor3, grad is a matrix
-        elif xdim == 2 and ydim == 3:
-            xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-            ygrad = x.dimshuffle(0, 1, "x") * gz.dimshuffle(0, "x", 1)
-
-        # x is a tensor3, y is a matrix, grad is a matrix
-        elif xdim == 3 and ydim == 2:
-            xgrad = gz.dimshuffle(0, 1, "x") * y.dimshuffle(0, "x", 1)
-            ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
-
-        # x is a tensor3, y is a tensor3, grad is a tensor3
-        elif xdim == ydim == 3:
-            xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-            ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
+        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
+        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
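A quick numerical check (illustrative only) that the single remaining rule is the usual matrix-calculus identity d(x @ y)/dx = gz @ y^T and d(x @ y)/dy = x^T @ gz, applied per batch:

import numpy as np

rng = np.random.default_rng(0)
x, y = rng.random((2, 3, 4)), rng.random((2, 4, 5))
gz = rng.random((2, 3, 5))

xgrad = np.matmul(gz, np.swapaxes(y, 1, 2))   # batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = np.matmul(np.swapaxes(x, 1, 2), gz)   # batched_dot(x.dimshuffle(0, 2, 1), gz)
assert xgrad.shape == x.shape and ygrad.shape == y.shape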
@@ -2105,6 +2026,7 @@ class BatchedDot(COp):
                 + " to BatchedDot.R_op should have the same shape, but "
                 f"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively"
             )
+
         if eval_points[0]:
             t1 = self(eval_points[0], inputs[1])
         if eval_points[1]:
@@ -2118,9 +2040,6 @@ class BatchedDot(COp):
         return [t2]

     def infer_shape(self, fgraph, node, shapes):
-        for shape_ in shapes:
-            if len(shape_) not in (2, 3):
-                raise NotImplementedError()
         xshp, yshp = shapes
         return [xshp[:-1] + yshp[2:]]
@@ -2157,14 +2076,24 @@ def batched_dot(a, b):
     elif b.ndim == 0:
         raise TypeError("b must have at least one (batch) axis")
     elif a.ndim == 1:
-        return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
+        return shape_padright(a, (b.ndim - 1)) * b
     elif b.ndim == 1:
-        return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
+        return a * shape_padright(b, (a.ndim - 1))
     elif a.ndim > 3 or b.ndim > 3:
         return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
     else:
-        # avoid circular import
-        return _batched_dot(a, b)
+        # If either a or b is a batched vector, expand dims and later squeeze them
+        expanded_axis = []
+        if a.ndim == 2:
+            a = expand_dims(a, axis=1)
+            expanded_axis.append(1)
+        if b.ndim == 2:
+            b = expand_dims(b, axis=2)
+            expanded_axis.append(2)
+        out = _batched_dot(a, b)
+        if expanded_axis:
+            out = out.squeeze(axis=expanded_axis)
+        return out


 def batched_tensordot(x, y, axes=2):
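A hedged usage sketch of the new helper path (names as in the diff; the reference result is computed with NumPy's einsum): two stacked vectors are expanded to tensor3, multiplied by the rank-3 Op, and squeezed back to one dot product per batch entry.

import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.blas import batched_dot

a = at.dmatrix("a")   # (batch, k) stacked vectors
b = at.dmatrix("b")   # (batch, k)
f = pytensor.function([a, b], batched_dot(a, b))

u = np.random.rand(5, 3)
v = np.random.rand(5, 3)
np.testing.assert_allclose(f(u, v), np.einsum("nk,nk->n", u, v))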
@@ -43,15 +43,6 @@ def test_jax_BatchedDot():
         with pytest.raises(TypeError):
             pytensor_jax_fn(*inputs)

-    # matrix . matrix
-    a = matrix("a")
-    a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3))
-    b = matrix("b")
-    b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3))
-    out = at_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])


 def test_jax_basic_multiout():
     rng = np.random.default_rng(213234)
@@ -843,23 +843,23 @@ def test_Softplus(x, exc):
     [
         (
             set_test_value(
-                at.dmatrix(),
-                rng.random(size=(3, 3)).astype("float64"),
+                at.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
             ),
             set_test_value(
-                at.dmatrix(),
-                rng.random(size=(3, 3)).astype("float64"),
+                at.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
             ),
             None,
         ),
         (
             set_test_value(
-                at.dmatrix(),
-                rng.random(size=(3, 3)).astype("float64"),
+                at.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
             ),
             set_test_value(
-                at.lmatrix(),
-                rng.poisson(size=(3, 3)).astype("int64"),
+                at.ltensor3(),
+                rng.poisson(size=(2, 3, 3)).astype("int64"),
             ),
             None,
         ),