提交 0b2cf52f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Short-circuit `as_scalar` common cases faster

上级 5f8cee6b
...@@ -987,25 +987,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant: ...@@ -987,25 +987,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
def as_scalar(x: Any, name: str | None = None) -> ScalarVariable: def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
from pytensor.tensor.basic import scalar_from_tensor if isinstance(x, ScalarVariable):
from pytensor.tensor.type import TensorType return x
if isinstance(x, Variable):
from pytensor.tensor.basic import scalar_from_tensor
from pytensor.tensor.type import TensorType
if isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError(f"Cannot convert {x} to a scalar type")
if isinstance(x, Apply): if isinstance(x, Apply):
# FIXME: Why do we support calling this with Apply?
# Also, if we do, why can't we support multiple outputs?
if len(x.outputs) != 1: if len(x.outputs) != 1:
raise ValueError( raise ValueError(
"It is ambiguous which output of a multi-output" "It is ambiguous which output of a multi-output"
" Op has to be fetched.", " Op has to be fetched.",
x, x,
) )
else: return as_scalar(x.outputs[0])
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x, ScalarVariable):
return x
elif isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError(f"Cannot convert {x} to a scalar type")
return constant(x) return constant(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论