提交 4fa6c415 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up comments and annotations in aesara.scan.utils.safe_new

上级 47cc3474
......@@ -34,11 +34,14 @@ if TYPE_CHECKING:
_logger = logging.getLogger("aesara.scan.utils")
def safe_new(x, tag="", dtype=None):
"""
Internal function that constructs a new variable from x with the same
def safe_new(
x: Variable, tag: str = "", dtype: Optional[Union[str, np.dtype]] = None
) -> Variable:
"""Clone variables.
Internal function that constructs a new variable from `x` with the same
type, but with a different name (old name + tag). This function is used
by gradient, or the R-op to construct new variables for the inputs of
by `gradient`, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original
graph and the newly constructed graph.
......@@ -51,14 +54,14 @@ def safe_new(x, tag="", dtype=None):
if isinstance(x, Constant):
if dtype and x.dtype != dtype:
casted_x = x.astype(dtype)
nwx = x.__class__(casted_x.type, x.data, x.name)
nwx = type(x)(casted_x.type, x.data, x.name)
nwx.tag = copy.copy(x.tag)
return nwx
else:
return x.clone()
# Note, as_tensor_variable will convert the Scalar into a
# TensorScalar that will require a ScalarFromTensor op,
# making the pushout optimization fail
# Note, `as_tensor_variable` will convert the `Scalar` into a
# `TensorScalar` that will require a `ScalarFromTensor` `Op`, making the
# push-out optimization fail
elif isinstance(x, aes.ScalarVariable):
if dtype:
nw_x = aes.get_scalar_type(dtype=dtype)()
......@@ -82,13 +85,13 @@ def safe_new(x, tag="", dtype=None):
# This could happen for example for random states
pass
# Cast x if needed. If x has a test value, this will also cast it.
# Cast `x` if needed. If `x` has a test value, this will also cast it.
if dtype and x.dtype != dtype:
x = x.astype(dtype)
nw_x = x.type()
nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# Preserve test values so that the `compute_test_value` option can be used.
# The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may
# not be the most efficient memory-wise, though.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论