提交 ee7a8bd5 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Use `jax.numpy.copy` directly

上级 f52c5678
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from pytensor.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
...@@ -69,7 +69,7 @@ def jax_funcify_DimShuffle(op, **kwargs): ...@@ -69,7 +69,7 @@ def jax_funcify_DimShuffle(op, **kwargs):
res = jnp.reshape(res, shape) res = jnp.reshape(res, shape)
if not op.inplace: if not op.inplace:
res = jnp_safe_copy(res) res = jnp.copy(res)
return res return res
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论