提交 3e7f00e5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Disable unused Assert logic in MergeFeature

上级 b03bd80f
......@@ -21,7 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import aesara
from aesara.assert_op import Assert, assert_op
from aesara.assert_op import Assert
from aesara.configdefaults import config
from aesara.graph import destroyhandler as dh
from aesara.graph.basic import (
......@@ -563,24 +563,25 @@ class MergeFeature(Feature):
merge_candidates = [c for c, i in clients if c in self.nodes_seen]
# Put all clients of Assert inputs (if exist) into merge_candidates
# TODO: Deactivated for now as this cause cycle in the graph.
# (There is a second deactivation part below.)
for i in []: # node.inputs:
if i.owner and isinstance(i.owner.op, Assert):
node_has_assert = True
i_clients = fgraph.clients[i.owner.inputs[0]]
assert_clients = [c for (c, _) in i_clients if c in self.nodes_seen]
for idx in range(len(assert_clients)):
client = assert_clients[idx]
if isinstance(i.owner.op, Assert):
o_clients = fgraph.clients[client.outputs[0]]
for c in o_clients:
if c[0] in self.nodes_seen:
assert_clients.append(c[0])
merge_candidates.extend(assert_clients)
# Put all clients of `CheckAndRaise` inputs (if exist) into
# `merge_candidates`
# TODO: Deactivated for now, because it can create cycles in a
# graph. (There is a second deactivation part below.)
# for i in node.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# node_has_assert = True
# i_clients = fgraph.clients[i.owner.inputs[0]]
# assert_clients = [c for (c, _) in i_clients if c in self.nodes_seen]
#
# for idx in range(len(assert_clients)):
# client = assert_clients[idx]
# if isinstance(i.owner.op, CheckAndRaise):
# o_clients = fgraph.clients[client.outputs[0]]
# for c in o_clients:
# if c[0] in self.nodes_seen:
# assert_clients.append(c[0])
#
# merge_candidates.extend(assert_clients)
else:
# If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them.
......@@ -598,29 +599,31 @@ class MergeFeature(Feature):
# Get input list of the candidate with assert removed
cand_inputs_assert_removed = []
# TODO: Deactivated while Assert merging is disabled. (See above and below.)
for i in []: # candidate.inputs:
if i.owner and isinstance(i.owner.op, Assert):
cand_has_assert = True
cand_inputs_assert_removed.append(i.owner.inputs[0])
else:
cand_inputs_assert_removed.append(i)
# TODO: Remove this when Assert merging is re-enabled. (See above.)
# Without Assert merging we can still look for identical Asserts,
# so we should not treat Asserts separately for now.
# TODO: Deactivated while `CheckAndRaise` merging is disabled. (See
# above and below.)
# for i in candidate.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# cand_has_assert = True
# cand_inputs_assert_removed.append(i.owner.inputs[0])
# else:
# cand_inputs_assert_removed.append(i)
# TODO: Remove this when `CheckAndRaise` merging is
# re-enabled. (See above.) Without `CheckAndRaise` merging we can
# still look for identical `CheckAndRaise`, so we should not treat
# `CheckAndRaise`s separately for now.
cand_inputs_assert_removed = candidate.inputs
# Get input list of the node with assert removed
if node_has_assert:
node_inputs_assert_removed = []
for i in node.inputs:
if i.owner and isinstance(i.owner.op, Assert):
node_inputs_assert_removed.append(i.owner.inputs[0])
else:
node_inputs_assert_removed.append(i)
else:
node_inputs_assert_removed = node.inputs
# if node_has_assert:
# node_inputs_assert_removed = []
# for i in node.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# node_inputs_assert_removed.append(i.owner.inputs[0])
# else:
# node_inputs_assert_removed.append(i)
# else:
node_inputs_assert_removed = node.inputs
inputs_match = all(
node_in is cand_in
......@@ -635,7 +638,7 @@ class MergeFeature(Feature):
continue
# replace node with candidate
if not (node_has_assert or cand_has_assert):
if not node_has_assert and not cand_has_assert:
# Schedule transfer of clients from node to candidate
pairs = list(
zip(
......@@ -645,32 +648,32 @@ class MergeFeature(Feature):
)
)
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
elif node_has_assert and not cand_has_assert:
pairs = list(
zip(
candidate.outputs,
node.outputs,
["merge"] * len(node.outputs),
)
)
else:
new_inputs = self.get_merged_assert_input(node, candidate)
new_node = node.op(*new_inputs)
pairs = list(
zip(
node.outputs,
new_node.owner.outputs,
["new_node"] * len(node.outputs),
)
) + list(
zip(
candidate.outputs,
new_node.owner.outputs,
["new_node"] * len(node.outputs),
)
)
# # if the current node has assert input, it should not be
# # replaced with a candidate node which has no assert input
# elif node_has_assert and not cand_has_assert:
# pairs = list(
# zip(
# candidate.outputs,
# node.outputs,
# ["merge"] * len(node.outputs),
# )
# )
# else:
# new_inputs = self.get_merged_assert_input(node, candidate)
# new_node = node.op(*new_inputs)
# pairs = list(
# zip(
# node.outputs,
# new_node.owner.outputs,
# ["new_node"] * len(node.outputs),
# )
# ) + list(
# zip(
# candidate.outputs,
# new_node.owner.outputs,
# ["new_node"] * len(node.outputs),
# )
# )
# transfer names
for pair in pairs:
......@@ -689,29 +692,35 @@ class MergeFeature(Feature):
if not node.inputs:
self.noinput_nodes.add(node)
def get_merged_assert_input(self, node, candidate):
new_inputs = []
for node_i, cand_i in zip(node.inputs, candidate.inputs):
# if node_i is assert
if node_i.owner and isinstance(node_i.owner.op, Assert):
# node_i is assert, cand_i is assert
if cand_i.owner and isinstance(cand_i.owner.op, Assert):
# Here two assert nodes are merged.
# Step 1. Merge conditions of both assert nodes.
# Step 2. Make the new assert node
node_cond = node_i.owner.inputs[1:]
cand_cond = cand_i.owner.inputs[1:]
new_cond = list(set(node_cond + cand_cond))
new_inputs.append(assert_op(node_i.owner.inputs[0], *new_cond))
# node_i is assert, cand_i is not assert
else:
new_inputs.append(node_i)
else:
# if node_i is not an assert node, append cand_i
new_inputs.append(cand_i)
return new_inputs
# def get_merged_assert_input(self, node, candidate):
# new_inputs = []
# for node_i, cand_i in zip(node.inputs, candidate.inputs):
# if node_i.owner and isinstance(node_i.owner.op, CheckAndRaise):
# if (
# cand_i.owner
# and isinstance(cand_i.owner.op, CheckAndRaise)
# and node_i.owner.op.exc_type == cand_i.owner.op.exc_type
# ):
# # Here two assert nodes are merged.
# # Step 1. Merge conditions of both assert nodes.
# # Step 2. Make the new assert node
# node_cond = node_i.owner.inputs[1:]
# cand_cond = cand_i.owner.inputs[1:]
# new_cond = list(set(node_cond + cand_cond))
# new_raise_op = CheckAndRaise(
# node_i.owner.op.exc_type,
# "; ".join([node_i.owner.op.msg, cand_i.owner.op.msg]),
# )
# new_inputs.append(new_raise_op(*(node_i.owner.inputs[:1] + new_cond)))
#
# # node_i is assert, cand_i is not assert
# else:
# new_inputs.append(node_i)
# else:
# # if node_i is not an assert node, append cand_i
# new_inputs.append(cand_i)
#
# return new_inputs
class MergeOptimizer(GlobalOptimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论