提交 5fdc130f authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Ricardo Vieira

add variable_depends_on

上级 8ad33179
...@@ -1603,6 +1603,28 @@ def apply_depends_on(apply: Apply, depends_on: Union[Apply, Collection[Apply]]) ...@@ -1603,6 +1603,28 @@ def apply_depends_on(apply: Apply, depends_on: Union[Apply, Collection[Apply]])
return False return False
def variable_depends_on(
variable: Variable, depends_on: Union[Variable, Collection[Variable]]
) -> bool:
"""Determine if any `depends_on` is in the graph given by ``variable``.
Parameters
----------
variable: Variable
Node to check
depends_on: Collection[Variable]
Nodes to check dependency on
Returns
-------
bool
"""
if not isinstance(depends_on, Collection):
depends_on = {depends_on}
else:
depends_on = set(depends_on)
return any(interim in depends_on for interim in ancestors([variable]))
def equal_computations( def equal_computations(
xs: List[Union[np.ndarray, Variable]], xs: List[Union[np.ndarray, Variable]],
ys: List[Union[np.ndarray, Variable]], ys: List[Union[np.ndarray, Variable]],
......
...@@ -23,6 +23,7 @@ from pytensor.graph.basic import ( ...@@ -23,6 +23,7 @@ from pytensor.graph.basic import (
io_toposort, io_toposort,
list_of_nodes, list_of_nodes,
orphans_between, orphans_between,
variable_depends_on,
vars_between, vars_between,
walk, walk,
) )
...@@ -675,3 +676,22 @@ def test_NominalVariable_create_variable_type(): ...@@ -675,3 +676,22 @@ def test_NominalVariable_create_variable_type():
assert type(ntv_unpkld) is type(ntv) assert type(ntv_unpkld) is type(ntv)
assert ntv_unpkld.equals(ntv) assert ntv_unpkld.equals(ntv)
assert ntv_unpkld is ntv assert ntv_unpkld is ntv
def test_variable_depends_on():
x = MyVariable(1)
x.name = "x"
y = MyVariable(1)
y.name = "y"
x2 = MyOp(x)
x2.name = "x2"
y2 = MyOp(y)
y2.name = "y2"
o = MyOp(x2, y)
assert variable_depends_on(o, x)
assert variable_depends_on(o, [x])
assert not variable_depends_on(o, [y2])
assert variable_depends_on(o, [y2, x])
assert not variable_depends_on(y, [y2])
assert variable_depends_on(y, [y])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论