提交 a0a494ab authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Vectorize ScalarFromTensor

上级 6dd61726
...@@ -710,6 +710,17 @@ class ScalarFromTensor(COp): ...@@ -710,6 +710,17 @@ class ScalarFromTensor(COp):
scalar_from_tensor = ScalarFromTensor() scalar_from_tensor = ScalarFromTensor()
@_vectorize_node.register(ScalarFromTensor)
def vectorize_scalar_from_tensor(op, node, batch_x):
if batch_x.ndim == 0:
return scalar_from_tensor(batch_x).owner
if batch_x.owner is not None:
return batch_x.owner
# Needed until we fix https://github.com/pymc-devs/pytensor/issues/902
return batch_x.copy().owner
# to be removed as we get the epydoc routine-documenting thing going # to be removed as we get the epydoc routine-documenting thing going
# -JB 20080924 # -JB 20080924
def _conversion(real_value: Op, name: str) -> Op: def _conversion(real_value: Op, name: str) -> Op:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论