提交 4fa6c415 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up comments and annotations in aesara.scan.utils.safe_new

上级 47cc3474
...@@ -34,11 +34,14 @@ if TYPE_CHECKING: ...@@ -34,11 +34,14 @@ if TYPE_CHECKING:
_logger = logging.getLogger("aesara.scan.utils") _logger = logging.getLogger("aesara.scan.utils")
def safe_new(x, tag="", dtype=None): def safe_new(
""" x: Variable, tag: str = "", dtype: Optional[Union[str, np.dtype]] = None
Internal function that constructs a new variable from x with the same ) -> Variable:
"""Clone variables.
Internal function that constructs a new variable from `x` with the same
type, but with a different name (old name + tag). This function is used type, but with a different name (old name + tag). This function is used
by gradient, or the R-op to construct new variables for the inputs of by `gradient`, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original the inner graph such that there is no interference between the original
graph and the newly constructed graph. graph and the newly constructed graph.
...@@ -51,14 +54,14 @@ def safe_new(x, tag="", dtype=None): ...@@ -51,14 +54,14 @@ def safe_new(x, tag="", dtype=None):
if isinstance(x, Constant): if isinstance(x, Constant):
if dtype and x.dtype != dtype: if dtype and x.dtype != dtype:
casted_x = x.astype(dtype) casted_x = x.astype(dtype)
nwx = x.__class__(casted_x.type, x.data, x.name) nwx = type(x)(casted_x.type, x.data, x.name)
nwx.tag = copy.copy(x.tag) nwx.tag = copy.copy(x.tag)
return nwx return nwx
else: else:
return x.clone() return x.clone()
# Note, as_tensor_variable will convert the Scalar into a # Note, `as_tensor_variable` will convert the `Scalar` into a
# TensorScalar that will require a ScalarFromTensor op, # `TensorScalar` that will require a `ScalarFromTensor` `Op`, making the
# making the pushout optimization fail # push-out optimization fail
elif isinstance(x, aes.ScalarVariable): elif isinstance(x, aes.ScalarVariable):
if dtype: if dtype:
nw_x = aes.get_scalar_type(dtype=dtype)() nw_x = aes.get_scalar_type(dtype=dtype)()
...@@ -82,13 +85,13 @@ def safe_new(x, tag="", dtype=None): ...@@ -82,13 +85,13 @@ def safe_new(x, tag="", dtype=None):
# This could happen for example for random states # This could happen for example for random states
pass pass
# Cast x if needed. If x has a test value, this will also cast it. # Cast `x` if needed. If `x` has a test value, this will also cast it.
if dtype and x.dtype != dtype: if dtype and x.dtype != dtype:
x = x.astype(dtype) x = x.astype(dtype)
nw_x = x.type() nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used. # Preserve test values so that the `compute_test_value` option can be used.
# The test value is deep-copied to ensure there can be no interactions # The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may # between test values, due to inplace operations for instance. This may
# not be the most efficient memory-wise, though. # not be the most efficient memory-wise, though.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论