提交 6aeed97c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify typify of Generators in JAXLinker

上级 9f80bdcd
......@@ -117,10 +117,8 @@ class JAXLinker(JITLinker):
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], Generator):
new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
sinput[0] = new_value
# Neet to convert Generator into JAX PRNGkey
sinput[0] = jax_typify(sinput[0])
thunk_inputs.append(sinput)
return thunk_inputs
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论