提交 6aeed97c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify typify of Generators in JAXLinker

上级 9f80bdcd
...@@ -117,10 +117,8 @@ class JAXLinker(JITLinker): ...@@ -117,10 +117,8 @@ class JAXLinker(JITLinker):
for n in self.fgraph.inputs: for n in self.fgraph.inputs:
sinput = storage_map[n] sinput = storage_map[n]
if isinstance(sinput[0], Generator): if isinstance(sinput[0], Generator):
new_value = jax_typify( # Neet to convert Generator into JAX PRNGkey
sinput[0], dtype=getattr(sinput[0], "dtype", None) sinput[0] = jax_typify(sinput[0])
)
sinput[0] = new_value
thunk_inputs.append(sinput) thunk_inputs.append(sinput)
return thunk_inputs return thunk_inputs
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论