Unverified 提交 2c03ecfb authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: GitHub

Defer the use of nfunc_spec in JAX scalar dispatch

上级 cb8b8ac8
......@@ -62,8 +62,6 @@ def check_if_inputs_scalars(node):
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op, node, **kwargs):
func_name = op.nfunc_spec[0]
# We dispatch some PyTensor operators to Python operators
# whenever the inputs are all scalars.
are_inputs_scalars = check_if_inputs_scalars(node)
......@@ -71,7 +69,7 @@ def jax_funcify_ScalarOp(op, node, **kwargs):
elemwise = elemwise_scalar(op)
if elemwise is not None:
return elemwise
func_name = op.nfunc_spec[0]
if "." in func_name:
jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论