提交 d58d482a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add FunctionGraph methods add_output, remove_node, remove_input, remove_output

上级 124ed5df
...@@ -126,13 +126,13 @@ class FunctionGraph(MetaObject): ...@@ -126,13 +126,13 @@ class FunctionGraph(MetaObject):
# outputs are cached in this field # outputs are cached in this field
self.apply_nodes: Set[Apply] = set() self.apply_nodes: Set[Apply] = set()
# Ditto for variable nodes. # It includes inputs, outputs, and all intermediate variables
# It must contain all fgraph.inputs and all apply_nodes # connecting the inputs and outputs. It also contains irrelevant
# outputs even if they aren't used in the graph. # outputs the nodes in `self.apply_nodes`.
self.variables: Set[Variable] = set() self.variables: Set[Variable] = set()
self.inputs: List[Variable] = [] self.inputs: List[Variable] = []
self.outputs: List[Variable] = list(outputs) self.outputs: List[Variable] = []
self.clients: Dict[Variable, List[ClientType]] = {} self.clients: Dict[Variable, List[ClientType]] = {}
for f in features: for f in features:
...@@ -152,13 +152,19 @@ class FunctionGraph(MetaObject): ...@@ -152,13 +152,19 @@ class FunctionGraph(MetaObject):
self.add_input(in_var, check=False) self.add_input(in_var, check=False)
for output in outputs: for output in outputs:
self.import_var(output, reason="init") self.add_output(output, reason="init")
for i, output in enumerate(outputs):
self.clients[output].append(("output", i))
self.profile = None self.profile = None
self.update_mapping = update_mapping self.update_mapping = update_mapping
def add_output(
self, var: Variable, reason: Optional[str] = None, import_missing: bool = False
):
"""Add a new variable as an output to this `FunctionGraph`."""
self.outputs.append(var)
self.import_var(var, reason=reason, import_missing=import_missing)
self.clients[var].append(("output", len(self.outputs) - 1))
def add_input(self, var: Variable, check: bool = True) -> None: def add_input(self, var: Variable, check: bool = True) -> None:
"""Add a new variable as an input to this `FunctionGraph`. """Add a new variable as an input to this `FunctionGraph`.
...@@ -172,7 +178,6 @@ class FunctionGraph(MetaObject): ...@@ -172,7 +178,6 @@ class FunctionGraph(MetaObject):
self.inputs.append(var) self.inputs.append(var)
self.setup_var(var) self.setup_var(var)
self.variables.add(var)
def setup_var(self, var: Variable) -> None: def setup_var(self, var: Variable) -> None:
"""Set up a variable so it belongs to this `FunctionGraph`. """Set up a variable so it belongs to this `FunctionGraph`.
...@@ -210,6 +215,7 @@ class FunctionGraph(MetaObject): ...@@ -210,6 +215,7 @@ class FunctionGraph(MetaObject):
var: Variable, var: Variable,
client_to_remove: ClientType, client_to_remove: ClientType,
reason: Optional[str] = None, reason: Optional[str] = None,
remove_if_empty: bool = False,
) -> None: ) -> None:
"""Recursively remove clients of a variable. """Recursively remove clients of a variable.
...@@ -222,11 +228,14 @@ class FunctionGraph(MetaObject): ...@@ -222,11 +228,14 @@ class FunctionGraph(MetaObject):
Parameters Parameters
---------- ----------
var : Variable var
The clients of `var` that will be removed. The clients of `var` that will be removed.
client_to_remove : pair of (Apply, int) client_to_remove
A ``(node, i)`` pair such that ``node.inputs[i]`` will no longer be A ``(node, i)`` pair such that ``node.inputs[i]`` will no longer be
`var` in this `FunctionGraph`. `var` in this `FunctionGraph`.
remove_if_empty
When ``True``, if `var`'s `Apply` node is removed, remove the
entry for `var` in `self.clients`.
""" """
...@@ -250,8 +259,6 @@ class FunctionGraph(MetaObject): ...@@ -250,8 +259,6 @@ class FunctionGraph(MetaObject):
# Now, `var` has no more clients, so check if we need to remove it # Now, `var` has no more clients, so check if we need to remove it
# and its `Apply` node # and its `Apply` node
if not var.owner: if not var.owner:
# The `var` is a `Constant` or an input without a client, so we
# remove it
self.variables.remove(var) self.variables.remove(var)
else: else:
apply_node = var.owner apply_node = var.owner
...@@ -274,12 +281,15 @@ class FunctionGraph(MetaObject): ...@@ -274,12 +281,15 @@ class FunctionGraph(MetaObject):
for i, in_var in enumerate(apply_node.inputs): for i, in_var in enumerate(apply_node.inputs):
removal_stack.append((in_var, (apply_node, i))) removal_stack.append((in_var, (apply_node, i)))
if remove_if_empty:
del self.clients[var]
def import_var( def import_var(
self, var: Variable, reason: Optional[str] = None, import_missing: bool = False self, var: Variable, reason: Optional[str] = None, import_missing: bool = False
) -> None: ) -> None:
"""Import variables into this `FunctionGraph`. """Import a `Variable` into this `FunctionGraph`.
This will also import the `variable`'s `Apply` node. This will import the `var`'s `Apply` node and inputs.
Parameters Parameters
---------- ----------
...@@ -517,6 +527,147 @@ class FunctionGraph(MetaObject): ...@@ -517,6 +527,147 @@ class FunctionGraph(MetaObject):
for var, new_var in pairs: for var, new_var in pairs:
self.replace(var, new_var, **kwargs) self.replace(var, new_var, **kwargs)
def _remove_output(self, idx: int):
"""Remove the output at index `idx` and update the indices in the clients entries.
`FunctionGraph.clients` contains entries like ``("output", i)`` under
each output variable in `FunctionGraph.outputs`. The ``i`` values
correspond to each output's location within the `FunctionGraph.outputs`
list, so, when an output is removed from the graph, all these entries
need to be updated. This method performs those updates.
TODO: We could track these entries in a new instance attribute and make
them lists, then each could be updated in-place very easily. This
seems fine, because the `FunctionGraph.clients` ``dict`` and list in
which they're contained are already being updated in-place.
"""
old_idx_mappings = tuple((out, i) for i, out in enumerate(self.outputs))
self.outputs.pop(idx)
new_idx = 0
for (out, old_idx) in old_idx_mappings:
if old_idx == idx:
continue
out_clients = self.clients[out]
arrow: ClientType = ("output", old_idx)
arrow_idx = out_clients.index(arrow)
out_clients[arrow_idx] = ("output", new_idx)
new_idx += 1
def remove_node(self, node: Apply, reason: Optional[str] = None):
"""Remove an `Apply` node from the `FunctionGraph`.
This will remove everything that depends on the outputs of `node`, as
well as any "orphaned" variables and nodes created by `node`'s removal.
"""
if node not in self.apply_nodes:
return
self.apply_nodes.remove(node)
if not hasattr(node.tag, "removed_by"):
node.tag.removed_by = []
node.tag.removed_by.append(str(reason))
# Remove the outputs of the node (i.e. everything "below" it)
for out in node.outputs:
self.variables.remove(out)
out_clients = self.clients.get(out, ())
while out_clients:
out_client, out_idx = out_clients.pop()
if out_client == "output":
self._remove_output(out_idx)
# TODO: We could short-circuit all of the graph walking and
# clear everything at once when all the outputs are gone.
# if not self.outputs:
# self.clients = {inp: [] for inp in self.inputs}
# self.variables = set()
# while self.apply_nodes:
# node = self.apply_nodes.pop()
# if not hasattr(node.tag, "removed_by"):
# node.tag.removed_by = []
#
# node.tag.removed_by.append(str(reason))
#
# self.execute_callbacks("on_prune", node, reason)
else:
assert isinstance(out_client, Apply)
self.remove_node(out_client, reason=reason)
if out in self.clients:
del self.clients[out]
# Remove all the arrows pointing to this `node`, and any orphaned
# variables created by removing those arrows
for inp_idx, inp in enumerate(node.inputs):
inp_clients: List[ClientType] = self.clients.get(inp, [])
arrow = (node, inp_idx)
if arrow not in inp_clients:
continue
inp_clients.remove(arrow)
if not inp_clients and inp not in self.outputs:
if inp.owner:
# If this input has no clients (after removing this arrow),
# is not an input (i.e. it has a non-`None` owner) or an
# output to the `FunctionGraph`, then it's an orphan
# We need to check whether or not this orphaned input's
# node is still needed in the graph
inp_node = inp.owner
if not any(
out in self.variables
for out in inp_node.outputs
if out is not inp
):
self.remove_node(inp_node, reason=reason)
else:
# This is an unused input
self.variables.remove(inp)
# The callbacks be triggered after everything has been removed so that
# the `FunctionGraph` state subscribers see is valid.
self.execute_callbacks("on_prune", node, reason)
def remove_input(self, input_idx: int, reason: Optional[str] = None):
"""Remove the input at index `input_idx`."""
var = self.inputs.pop(input_idx)
for client, idx in list(self.clients[var]):
if client == "output":
out_var = self.outputs[idx]
out_node = out_var.owner
if out_node is None:
assert out_var in self.inputs
self.outputs.pop(idx)
continue
client_node = out_node
else:
assert isinstance(client, Apply)
client_node = client
self.remove_node(client_node, reason=reason)
def remove_output(self, output_idx: int, reason: Optional[str] = None):
"""Remove the output at index `input_idx`."""
var = self.outputs[output_idx]
self._remove_output(output_idx)
self.remove_client(
var, ("output", output_idx), reason=reason, remove_if_empty=True
)
def attach_feature(self, feature: Feature) -> None: def attach_feature(self, feature: Feature) -> None:
"""Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback.""" """Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback."""
# Filter out literally identical `Feature`s # Filter out literally identical `Feature`s
...@@ -668,9 +819,7 @@ class FunctionGraph(MetaObject): ...@@ -668,9 +819,7 @@ class FunctionGraph(MetaObject):
nodes_missing = nodes.difference(self.apply_nodes) nodes_missing = nodes.difference(self.apply_nodes)
nodes_excess = self.apply_nodes.difference(nodes) nodes_excess = self.apply_nodes.difference(nodes)
raise Exception( raise Exception(
"The nodes are inappropriately cached. missing, in excess: ", f"The following nodes are inappropriately cached:\nmissing: {nodes_missing}\nin excess: {nodes_excess}"
nodes_missing,
nodes_excess,
) )
for node in nodes: for node in nodes:
for i, variable in enumerate(node.inputs): for i, variable in enumerate(node.inputs):
...@@ -684,9 +833,7 @@ class FunctionGraph(MetaObject): ...@@ -684,9 +833,7 @@ class FunctionGraph(MetaObject):
vars_missing = variables.difference(self.variables) vars_missing = variables.difference(self.variables)
vars_excess = self.variables.difference(variables) vars_excess = self.variables.difference(variables)
raise Exception( raise Exception(
"The variables are inappropriately cached. missing, in excess: ", f"The following variables are inappropriately cached:\nmissing: {vars_missing}\nin excess: {vars_excess}"
vars_missing,
vars_excess,
) )
for variable in variables: for variable in variables:
if ( if (
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from tests.graph.utils import MyConstant, MyVariable, MyVariable2, op1, op2, op3 from tests.graph.utils import MyConstant, MyOp, MyVariable, MyVariable2, op1, op2, op3
class TestFunctionGraph: class TestFunctionGraph:
...@@ -60,7 +60,7 @@ class TestFunctionGraph: ...@@ -60,7 +60,7 @@ class TestFunctionGraph:
assert fg.variables == {var1, var2, var3, var4} assert fg.variables == {var1, var2, var3, var4}
assert fg.get_clients(var1) == [(var3.owner, 0)] assert fg.get_clients(var1) == [(var3.owner, 0)]
assert fg.get_clients(var2) == [(var4.owner, 1)] assert fg.get_clients(var2) == [(var4.owner, 1)]
assert fg.get_clients(var3) == [(var4.owner, 0), ("output", 0)] assert fg.get_clients(var3) == [("output", 0), (var4.owner, 0)]
assert fg.get_clients(var4) == [("output", 1)] assert fg.get_clients(var4) == [("output", 1)]
varC = MyConstant("varC") varC = MyConstant("varC")
...@@ -304,7 +304,7 @@ class TestFunctionGraph: ...@@ -304,7 +304,7 @@ class TestFunctionGraph:
# FIXME TODO XXX: This breaks the state of the `FunctionGraph`, # FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
# because it doesn't check for validity of the replacement *first*. # because it doesn't check for validity of the replacement *first*.
fg.replace(var1, var0, verbose=True) fg.replace(var1, var0)
def test_check_integrity(self): def test_check_integrity(self):
...@@ -315,7 +315,7 @@ class TestFunctionGraph: ...@@ -315,7 +315,7 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2) var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False) fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
with pytest.raises(Exception, match="The nodes are .*"): with pytest.raises(Exception, match="The following nodes are .*"):
fg.apply_nodes.remove(var5.owner) fg.apply_nodes.remove(var5.owner)
fg.check_integrity() fg.check_integrity()
...@@ -328,7 +328,7 @@ class TestFunctionGraph: ...@@ -328,7 +328,7 @@ class TestFunctionGraph:
fg.add_client(var2, (var5.owner, 1)) fg.add_client(var2, (var5.owner, 1))
with pytest.raises(Exception, match="The variables are.*"): with pytest.raises(Exception, match="The following variables are.*"):
fg.variables.remove(var4) fg.variables.remove(var4)
fg.check_integrity() fg.check_integrity()
...@@ -386,3 +386,300 @@ class TestFunctionGraph: ...@@ -386,3 +386,300 @@ class TestFunctionGraph:
assert var3.owner in fg assert var3.owner in fg
assert var5 in fg assert var5 in fg
assert var5.owner in fg assert var5.owner in fg
def test_remove_node(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
node1_out = op1(var1)
node2_out = op2(var2, node1_out)
node3_out = op3(node2_out)
fg = FunctionGraph([var1, var2], [node3_out], clone=False)
fg.remove_node(node3_out.owner)
fg.check_integrity()
assert not fg.apply_nodes
fg = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)
fg.remove_node(node3_out.owner)
fg.check_integrity()
assert fg.apply_nodes == {node1_out.owner, node2_out.owner}
fg = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)
fg.remove_node(node2_out.owner)
fg.check_integrity()
assert not fg.apply_nodes
def test_remove_output(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
node1_out = op1(var1)
node2_out = op2(var2, node1_out)
node3_out = op3(node2_out)
fg = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)
fg.remove_output(0)
fg.check_integrity()
assert fg.apply_nodes == {node1_out.owner, node2_out.owner, node3_out.owner}
assert fg.inputs == [var1, var2]
assert fg.outputs == [node3_out]
fg = FunctionGraph([var1, var2], [node2_out, node3_out], clone=False)
fg.remove_output(1)
fg.check_integrity()
assert fg.apply_nodes == {node1_out.owner, node2_out.owner}
assert fg.inputs == [var1, var2]
assert fg.outputs == [node2_out]
fg = FunctionGraph([var1, var2], [node2_out, node3_out, var1], clone=False)
fg.remove_output(2)
fg.check_integrity()
assert fg.apply_nodes == {node1_out.owner, node2_out.owner, node3_out.owner}
assert fg.inputs == [var1, var2]
assert fg.outputs == [node2_out, node3_out]
fg = FunctionGraph([var1, var2], [var1], clone=False)
fg.remove_output(0)
fg.check_integrity()
assert fg.inputs == [var1, var2]
assert fg.outputs == []
def test_remove_output_2(self):
var0 = MyVariable("var0")
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = MyVariable("var3")
var4 = MyVariable("var4")
op1_out = op1(var1, var0)
out0 = op2(op1_out, var2)
out1 = op1(var3, var4)
out1.name = "out1"
out2 = op1(out1, var0)
out2.name = "out2"
out3 = out1
fg = FunctionGraph(
[var0, var1, var2, var3, var4],
[out0, out1, out2, out3],
clone=False,
)
fg.remove_output(1)
fg.check_integrity()
assert fg.outputs == [out0, out2, out3]
fg = FunctionGraph(
[var0, var1, var2, var3, var4],
[out0, out1, out2, var4, var4],
clone=False,
)
fg.remove_output(3)
fg.check_integrity()
assert fg.inputs == [var0, var1, var2, var3, var4]
assert fg.outputs == [out0, out1, out2, var4]
def test_remove_output_3(self):
var0 = MyVariable("var0")
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = MyVariable("var3")
var4 = MyVariable("var4")
var5 = MyVariable("var5")
var6 = MyVariable("var6")
op1_out = op1(var1, var0)
out0 = op2(op1_out, var2)
out1 = op1(var3, var4)
out1.name = "out1"
out2 = op1(op1_out, var5)
out2.name = "out2"
out3 = op1(var3, var6)
out3.name = "out3"
out4 = op1_out
out5 = var3
fg = FunctionGraph(
[var0, var1, var2, var3, var4, var5, var6],
[out0, out1, out2, out3, out4, out5],
clone=False,
)
fg.remove_output(1)
fg.check_integrity()
assert fg.inputs == [var0, var1, var2, var3, var4, var5, var6]
assert fg.outputs == [out0, out2, out3, out4, out5]
assert out1 not in fg.clients
def test_remove_input(self):
var0 = MyVariable("var0")
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = MyVariable("var3")
var4 = MyVariable("var4")
op1_out = op1(var1, var0)
out0 = op2(op1_out, var2)
out1 = op1(var3, var4)
out1.name = "out1"
out2 = op1(out1, var0)
out2.name = "out2"
out3 = out1
fg = FunctionGraph(
[var0, var1, var2, var3, var4],
[out0, out1, out2, out3],
clone=False,
)
fg.remove_input(4)
fg.check_integrity()
assert fg.inputs == [var0, var1, var2, var3]
assert fg.outputs == [out0]
def test_remove_in_and_out(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
op1_out = op1(var2, var1)
op2_out = op2(op1_out, var2)
op3_out = op3(op2_out, var2, var2)
fg = FunctionGraph([var1, var2], [op1_out, op3_out], clone=False)
# Remove an output
fg.remove_output(1)
fg.check_integrity()
assert fg.outputs == [op1_out]
assert op3_out not in fg.clients
assert not any(
op3_out.owner in clients for clients in sum(fg.clients.values(), [])
)
# Remove an input
fg.remove_input(0)
fg.check_integrity()
assert var1 not in fg.variables
assert fg.inputs == [var2]
assert fg.outputs == []
assert not any(
op1_out.owner in clients for clients in sum(fg.clients.values(), [])
)
def test_remove_duplicates(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
op1_out = op1(var2, var1)
op2_out = op2(op1_out, var2)
op3_out = op3(op2_out, var2, var2)
fg = FunctionGraph([var1, var1, var2], [op1_out, op3_out, op3_out], clone=False)
fg.remove_output(2)
fg.check_integrity()
assert fg.outputs == [op1_out, op3_out]
fg.remove_input(0)
fg.check_integrity()
assert var1 not in fg.variables
assert fg.inputs == [var1, var2]
assert fg.outputs == []
def test_remove_output_empty(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
op1_out = op1(var1)
op3_out = op3(op1_out, var2)
fg = FunctionGraph([var1, var2], [op3_out], clone=False)
fg.remove_output(0)
fg.check_integrity()
assert fg.inputs == [var1, var2]
assert not fg.apply_nodes
assert op1_out not in fg.clients
assert not any(
op1_out.owner in clients for clients in sum(fg.clients.values(), [])
)
assert not any(
op3_out.owner in clients for clients in sum(fg.clients.values(), [])
)
def test_remove_node_multi_out(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
multi_op = MyOp("mop", n_outs=2)
op1_out = op1(var1)
mop_out_1, mop_out_2 = multi_op(op1_out, var2)
op3_out = op3(mop_out_2)
fg = FunctionGraph([var1, var2], [mop_out_1, op3_out], clone=False)
fg.remove_node(mop_out_1.owner)
fg.check_integrity()
assert fg.inputs == [var1, var2]
assert fg.outputs == []
assert mop_out_1 not in fg.clients
assert mop_out_2 not in fg.clients
assert mop_out_1 not in fg.variables
assert mop_out_2 not in fg.variables
mop1_out_1, mop1_out_2 = multi_op(var1)
op2_out = op2(mop1_out_1)
op3_out = op3(mop1_out_1, mop1_out_2)
fg = FunctionGraph([var1], [op2_out, op3_out], clone=False)
fg.remove_node(op3_out.owner)
fg.check_integrity()
assert fg.inputs == [var1]
assert fg.outputs == [op2_out]
# If we only want to track "active" variables in the graphs, the
# following would need to be true, as well
# assert mop1_out_2 not in fg.clients
# assert mop1_out_2 not in fg.variables
fg = FunctionGraph([var1], [op2_out, op3_out, mop1_out_2], clone=False)
fg.remove_node(op3_out.owner)
fg.check_integrity()
assert fg.inputs == [var1]
assert fg.outputs == [op2_out, mop1_out_2]
assert mop1_out_2 in fg.clients
assert mop1_out_2 in fg.variables
assert mop1_out_2 in fg.outputs
def test_empty(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
fg = FunctionGraph([var1, var2], [], clone=False)
fg.check_integrity()
assert fg.inputs == [var1, var2]
assert fg.outputs == []
assert not fg.variables
assert not fg.apply_nodes
assert fg.clients == {var1: [], var2: []}
...@@ -46,19 +46,20 @@ def MyVariable2(name): ...@@ -46,19 +46,20 @@ def MyVariable2(name):
class MyOp(Op): class MyOp(Op):
def __init__(self, name, dmap=None, x=None): def __init__(self, name, dmap=None, x=None, n_outs=1):
self.name = name self.name = name
if dmap is None: if dmap is None:
dmap = {} dmap = {}
self.destroy_map = dmap self.destroy_map = dmap
self.x = x self.x = x
self.n_outs = n_outs
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(is_variable, inputs)) inputs = list(map(is_variable, inputs))
for input in inputs: for input in inputs:
if not isinstance(input.type, MyType): if not isinstance(input.type, MyType):
raise Exception("Error 1") raise Exception("Error 1")
outputs = [MyType()()] outputs = [MyType()() for i in range(self.n_outs)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -71,18 +72,19 @@ class MyOp(Op): ...@@ -71,18 +72,19 @@ class MyOp(Op):
return self.name return self.name
def __eq__(self, other): def __eq__(self, other):
# rval = (self is other) or (isinstance(other, MyOp) and self.x is not None and self.x == other.x and self.name == other.name)
rval = (self is other) or ( rval = (self is other) or (
isinstance(other, MyOp) and self.x is not None and self.x == other.x isinstance(other, MyOp)
and self.x is not None
and self.x == other.x
and self.n_outs == other.n_outs
) )
return rval return rval
def __hash__(self): def __hash__(self):
# return hash(self.x if self.x is not None else id(self)) ^ hash(self.name)
if self.x is not None: if self.x is not None:
return hash(self.x) return hash((self.x, self.n_outs))
else: else:
return id(self) return hash((id(self), self.n_outs))
class MyOpCastType2(MyOp): class MyOpCastType2(MyOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论