提交 0b94be01 作者: Ricardo Vieira 提交者: Ricardo Vieira

Reduce overhead of Scalar python implementation

上级 0b07727b
......@@ -36,7 +36,6 @@ from pytensor.printing import pprint
from pytensor.utils import (
apply_across_args,
difference,
from_return_values,
to_return_values,
)
......@@ -1081,6 +1080,16 @@ def real_out(type):
return (type,)
def _cast_to_promised_scalar_dtype(x, dtype):
try:
return x.astype(dtype)
except AttributeError:
if dtype == "bool":
return np.bool_(x)
else:
return getattr(np, dtype)(x)
class ScalarOp(COp):
nin = -1
nout = 1
......@@ -1134,28 +1143,18 @@ class ScalarOp(COp):
else:
raise NotImplementedError(f"Cannot calculate the output types for {self}")
@staticmethod
def _cast_scalar(x, dtype):
if hasattr(x, "astype"):
return x.astype(dtype)
elif dtype == "bool":
return np.bool_(x)
else:
return getattr(np, dtype)(x)
# Store the value(s) returned by `self.impl(*inputs)` into `output_storage`,
# each cast to the dtype of the matching entry of `node.outputs`.
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were stripped, so two variants of several statements appear
# back to back (two assignments to `output_storage[0][0]`, two argument lines
# inside `zip(...)`, two assignments to `storage[0]`). Presumably the lines
# using `self._cast_scalar` / `from_return_values` are the removed side and
# the ones using `_cast_to_promised_scalar_dtype` are the added side --
# confirm against the repository before treating either variant as current.
def perform(self, node, inputs, output_storage):
if self.nout == 1:
# Single-output case: the result goes into the first storage cell.
dtype = node.outputs[0].dtype
output_storage[0][0] = self._cast_scalar(self.impl(*inputs), dtype)
output_storage[0][0] = _cast_to_promised_scalar_dtype(
self.impl(*inputs),
node.outputs[0].dtype,
)
else:
# Multi-output case: pair each output variable with its storage cell and
# the corresponding value produced by `impl`.
variables = from_return_values(self.impl(*inputs))
assert len(variables) == len(output_storage)
# strict=False because we are in a hot loop
for out, storage, variable in zip(
node.outputs, output_storage, variables, strict=False
node.outputs, output_storage, self.impl(*inputs), strict=False
):
dtype = out.dtype
storage[0] = self._cast_scalar(variable, dtype)
storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype)
def impl(self, *inputs):
    """Python implementation of the scalar operation.

    This version always raises ``MethodNotDefined``, built from the method
    name ``"impl"``, the instance's type and its class name.
    """
    op_type = type(self)
    op_name = self.__class__.__name__
    raise MethodNotDefined("impl", op_type, op_name)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论