提交 0b2cf52f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Short-circuit `as_scalar` common cases faster

上级 5f8cee6b
......@@ -987,25 +987,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
if isinstance(x, ScalarVariable):
return x
if isinstance(x, Variable):
from pytensor.tensor.basic import scalar_from_tensor
from pytensor.tensor.type import TensorType
if isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError(f"Cannot convert {x} to a scalar type")
if isinstance(x, Apply):
# FIXME: Why do we support calling this with Apply?
# Also, if we do, why can't we support multiple outputs?
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output"
" Op has to be fetched.",
x,
)
else:
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x, ScalarVariable):
return x
elif isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError(f"Cannot convert {x} to a scalar type")
return as_scalar(x.outputs[0])
return constant(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论