Commit 47f76375 authored by Rémi Louf, committed by Ricardo Vieira

Dispatch some `Op`s to Python operators when scalar inputs

Parent efb4996f
......@@ -5,14 +5,57 @@ import jax.numpy as jnp
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scalar import Softplus
from pytensor.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from pytensor.scalar.basic import (
Add,
Cast,
Clip,
Composite,
Identity,
IntDiv,
Mod,
Mul,
ScalarOp,
Second,
Sub,
)
from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
def check_if_inputs_scalars(node):
    """Check whether every input of an `Elemwise` node is a scalar value.

    `jax.lax`/`jax.numpy` functions systematically return `TracedArray`s,
    whereas the corresponding Python operators return concrete values when
    given concrete values. To compile as many graphs as possible we want to
    keep values concrete whenever we can, so PyTensor operators are
    dispatched differently depending on whether all inputs are scalars.
    """
    input_ndims = [inp.type.ndim for inp in node.inputs]
    for ndim in input_ndims:
        try:
            if ndim > 0:
                return False
        except TypeError:
            # `ndim` is not comparable to an int, so we cannot prove the
            # input is a scalar; fall back to the generic JAX dispatch.
            return False
    return True
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op, **kwargs):
def jax_funcify_ScalarOp(op, node, **kwargs):
func_name = op.nfunc_spec[0]
# We dispatch some PyTensor operators to Python operators
# whenever the inputs are all scalars.
are_inputs_scalars = check_if_inputs_scalars(node)
if are_inputs_scalars:
elemwise = elemwise_scalar(op)
if elemwise is not None:
return elemwise
if "." in func_name:
jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
else:
......@@ -38,6 +81,54 @@ def jax_funcify_ScalarOp(op, **kwargs):
return jnp_func
@functools.singledispatch
def elemwise_scalar(op):
    """Return a Python-operator implementation for `op`, if one is registered.

    Default (unregistered) case: no special scalar implementation exists, so
    the caller should fall back to the regular JAX dispatch.
    """
    return None
@elemwise_scalar.register(Add)
def elemwise_scalar_add(op):
    """Return a Python-level sum so concrete scalar inputs stay concrete."""

    def elemwise(*inputs):
        # Accumulate with Python `+`, starting from 0 (same as `sum`).
        total = 0
        for value in inputs:
            total = total + value
        return total

    return elemwise
@elemwise_scalar.register(Mul)
def elemwise_scalar_mul(op):
    """Return a Python-level product so concrete scalar inputs stay concrete.

    A plain accumulation loop replaces the previous function-scope
    ``import operator`` / ``from functools import reduce``: it is
    equivalent to ``reduce(operator.mul, inputs, 1)``, needs no imports,
    and matches the style of the other scalar handlers.
    """

    def elemwise(*inputs):
        # Multiply with Python `*`, starting from the identity 1.
        product = 1
        for value in inputs:
            product = product * value
        return product

    return elemwise
@elemwise_scalar.register(Sub)
def elemwise_scalar_sub(op):
    """Return Python's binary ``-`` so concrete scalar inputs stay concrete."""

    def elemwise(x, y):
        difference = x - y
        return difference

    return elemwise
@elemwise_scalar.register(IntDiv)
def elemwise_scalar_intdiv(op):
    """Return Python's floor division so concrete scalar inputs stay concrete."""

    def elemwise(x, y):
        quotient = x // y
        return quotient

    return elemwise
@elemwise_scalar.register(Mod)
def elemwise_scalar_mod(op):
    """Return Python's modulo operator so concrete scalar inputs stay concrete."""

    def elemwise(x, y):
        remainder = x % y
        return remainder

    return elemwise
@jax_funcify.register(Cast)
def jax_funcify_Cast(op, **kwargs):
def cast(x):
......
......@@ -161,6 +161,42 @@ def test_jax_variadic_Scalar():
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_add_scalars():
    """Adding shape scalars should compile via the Python `+` dispatch."""
    mat = at.matrix("x")
    total = mat.shape[0] + mat.shape[0] + mat.shape[1]
    graph_output = at.ones(total).astype(config.floatX)
    fgraph = FunctionGraph([mat], [graph_output])
    test_value = np.ones((2, 3)).astype(config.floatX)
    compare_jax_and_py(fgraph, [test_value])
def test_mul_scalars():
    """Multiplying shape scalars should compile via the Python `*` dispatch."""
    mat = at.matrix("x")
    product = mat.shape[0] * mat.shape[0] * mat.shape[1]
    graph_output = at.ones(product).astype(config.floatX)
    fgraph = FunctionGraph([mat], [graph_output])
    test_value = np.ones((2, 3)).astype(config.floatX)
    compare_jax_and_py(fgraph, [test_value])
def test_div_scalars():
    """Floor-dividing shape scalars should compile via the Python `//` dispatch."""
    mat = at.matrix("x")
    quotient = mat.shape[0] // mat.shape[1]
    graph_output = at.ones(quotient).astype(config.floatX)
    fgraph = FunctionGraph([mat], [graph_output])
    test_value = np.ones((12, 3)).astype(config.floatX)
    compare_jax_and_py(fgraph, [test_value])
def test_mod_scalars():
    """Taking the modulo of shape scalars should compile via the Python `%` dispatch."""
    mat = at.matrix("x")
    remainder = mat.shape[0] % mat.shape[1]
    graph_output = at.ones(remainder).astype(config.floatX)
    fgraph = FunctionGraph([mat], [graph_output])
    test_value = np.ones((12, 3)).astype(config.floatX)
    compare_jax_and_py(fgraph, [test_value])
def test_jax_multioutput():
x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment