提交 79112666 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unused CheckAndRaise code and minor refactoring to MergeOptimizer

上级 48c9ef88
......@@ -529,9 +529,9 @@ class MergeFeature(Feature):
self.on_import(fgraph, node, "on_attach")
def on_change_input(self, fgraph, node, i, r, new_r, reason):
# If inputs to node change, it is not guaranteed that it is distinct
# from the other nodes in nodes_seen
if node in self.nodes_seen:
# If inputs to a node change, it's not guaranteed that the node is
# distinct from the other nodes in `self.nodes_seen`.
self.nodes_seen.discard(node)
self.process_node(fgraph, node)
......@@ -580,52 +580,30 @@ class MergeFeature(Feature):
self.seen_constants.add(id(c))
def process_node(self, fgraph, node):
"""Check if a `node` can be merged, and queue that replacement."""
r"""Check if a `node` can be merged, and queue that replacement.
When `node` is changed we check for other nodes (via the clients map)
that depend on the same inputs. If any of those other nodes have the
same inputs and `Op` as `node`, they are queued to be merged.
"""
if node in self.nodes_seen:
return
node_has_assert = False
# These asserts ensure that the fgraph has set the clients field
# properly.
# The clients should at least contain `node` itself!
if node.inputs:
# Take the smallest clients list. Some ops like elemwise
# have optimization that put constant as the first inputs.
# As constant have in general more clients than other type of nodes
# using always inputs[0] make us look at more nodes.
# Always pick the smallest clints list between inputs 0
# and -1 speed up optimization.
a_clients = fgraph.clients[node.inputs[0]]
b_clients = fgraph.clients[node.inputs[-1]]
if len(a_clients) < len(b_clients):
clients = a_clients
else:
clients = b_clients
# We use the smallest clients list. Some `Op`s like `Elemwise`
# have optimizations that put constants as the first inputs. Since
# constants generally have more clients than other types of nodes,
# using `node.inputs[0]` will make us look at more nodes on
# average, so by picking the smallest clients list, we might speed
# things up?
clients = sorted(
(fgraph.clients[inp] for inp in node.inputs), key=lambda x: len(x)
)[0]
assert len(clients) > 0
merge_candidates = [c for c, i in clients if c in self.nodes_seen]
# Put all clients of `CheckAndRaise` inputs (if exist) into
# `merge_candidates`
# TODO: Deactivated for now, because it can create cycles in a
# graph. (There is a second deactivation part below.)
# for i in node.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# node_has_assert = True
# i_clients = fgraph.clients[i.owner.inputs[0]]
# assert_clients = [c for (c, _) in i_clients if c in self.nodes_seen]
#
# for idx in range(len(assert_clients)):
# client = assert_clients[idx]
# if isinstance(i.owner.op, CheckAndRaise):
# o_clients = fgraph.clients[client.outputs[0]]
# for c in o_clients:
# if c[0] in self.nodes_seen:
# assert_clients.append(c[0])
#
# merge_candidates.extend(assert_clients)
else:
# If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them.
......@@ -639,41 +617,9 @@ class MergeFeature(Feature):
if len(node.inputs) != len(candidate.inputs):
continue
cand_has_assert = False
# Get input list of the candidate with assert removed
cand_inputs_assert_removed = []
# TODO: Deactivated while `CheckAndRaise` merging is disabled. (See
# above and below.)
# for i in candidate.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# cand_has_assert = True
# cand_inputs_assert_removed.append(i.owner.inputs[0])
# else:
# cand_inputs_assert_removed.append(i)
# TODO: Remove this when `CheckAndRaise` merging is
# re-enabled. (See above.) Without `CheckAndRaise` merging we can
# still look for identical `CheckAndRaise`, so we should not treat
# `CheckAndRaise`s separately for now.
cand_inputs_assert_removed = candidate.inputs
# Get input list of the node with assert removed
# if node_has_assert:
# node_inputs_assert_removed = []
# for i in node.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# node_inputs_assert_removed.append(i.owner.inputs[0])
# else:
# node_inputs_assert_removed.append(i)
# else:
node_inputs_assert_removed = node.inputs
inputs_match = all(
node_in is cand_in
for node_in, cand_in in zip(
node_inputs_assert_removed, cand_inputs_assert_removed
)
for node_in, cand_in in zip(node.inputs, candidate.inputs)
)
if inputs_match and node.op == candidate.op:
......@@ -681,51 +627,14 @@ class MergeFeature(Feature):
# They were already tried, and there was an error
continue
# replace node with candidate
if not node_has_assert and not cand_has_assert:
# Schedule transfer of clients from node to candidate
pairs = list(
zip(
node.outputs,
candidate.outputs,
["merge"] * len(node.outputs),
)
# Schedule transfer of clients from node to candidate
pairs = list(
zip(
node.outputs,
candidate.outputs,
["merge"] * len(node.outputs),
)
# # if the current node has assert input, it should not be
# # replaced with a candidate node which has no assert input
# elif node_has_assert and not cand_has_assert:
# pairs = list(
# zip(
# candidate.outputs,
# node.outputs,
# ["merge"] * len(node.outputs),
# )
# )
# else:
# new_inputs = self.get_merged_assert_input(node, candidate)
# new_node = node.op(*new_inputs)
# pairs = list(
# zip(
# node.outputs,
# new_node.owner.outputs,
# ["new_node"] * len(node.outputs),
# )
# ) + list(
# zip(
# candidate.outputs,
# new_node.owner.outputs,
# ["new_node"] * len(node.outputs),
# )
# )
# transfer names
for pair in pairs:
node_output, cand_output = pair[:2]
# clobber old name with new one
# it's arbitrary... one of the names has to go
if node_output.name:
cand_output.name = node_output.name
)
replacement_candidates.append(pairs)
......@@ -736,36 +645,6 @@ class MergeFeature(Feature):
if not node.inputs:
self.noinput_nodes.add(node)
# def get_merged_assert_input(self, node, candidate):
# new_inputs = []
# for node_i, cand_i in zip(node.inputs, candidate.inputs):
# if node_i.owner and isinstance(node_i.owner.op, CheckAndRaise):
# if (
# cand_i.owner
# and isinstance(cand_i.owner.op, CheckAndRaise)
# and node_i.owner.op.exc_type == cand_i.owner.op.exc_type
# ):
# # Here two assert nodes are merged.
# # Step 1. Merge conditions of both assert nodes.
# # Step 2. Make the new assert node
# node_cond = node_i.owner.inputs[1:]
# cand_cond = cand_i.owner.inputs[1:]
# new_cond = list(set(node_cond + cand_cond))
# new_raise_op = CheckAndRaise(
# node_i.owner.op.exc_type,
# "; ".join([node_i.owner.op.msg, cand_i.owner.op.msg]),
# )
# new_inputs.append(new_raise_op(*(node_i.owner.inputs[:1] + new_cond)))
#
# # node_i is assert, cand_i is not assert
# else:
# new_inputs.append(node_i)
# else:
# # if node_i is not an assert node, append cand_i
# new_inputs.append(cand_i)
#
# return new_inputs
class MergeOptimizer(GlobalOptimizer):
r"""Merges parts of the graph that are identical and redundant.
......@@ -786,10 +665,6 @@ class MergeOptimizer(GlobalOptimizer):
fgraph.attach_feature(MergeFeature())
def apply(self, fgraph):
from aesara.raise_op import CheckAndRaise
# Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way.
sched = fgraph.merge_feature.scheduled
nb_fail = 0
t0 = time.time()
......@@ -804,51 +679,32 @@ class MergeOptimizer(GlobalOptimizer):
pairs_list = sched.pop()
success = True
for pairs_ in pairs_list:
# We must check again the equivalence, as the graph
# could've changed. If so, doing the replacement can
# introduce a node that depends on itself. Doing the
# full check of such cycles every time is very time
# consuming. I think this double check is faster than
# doing the full cycle check. The full cycle check is
# skipped by validate() if the graph doesn't contain
# destroyers.
var, candidate, merge_mode = pairs_[0]
# We must check again the equivalence, as the graph could've
# changed. If so, doing the replacement can introduce a node
# that depends on itself. Doing the full check of such cycles
# every time is very time consuming. I think this double check
# is faster than doing the full cycle check. The full cycle
# check is skipped by `Validator.validate` if the graph doesn't
# contain destroyers.
var, candidate_var, merge_mode = pairs_[0]
if merge_mode == "new_node" and var in fgraph.variables:
pass
elif var not in fgraph.variables or candidate not in fgraph.variables:
elif (
var not in fgraph.variables or candidate_var not in fgraph.variables
):
continue
# Keep len(item) == 2 for item in pairs
pairs = [pair[:2] for pair in pairs_]
if var.owner and candidate.owner:
node = var.owner
candidate = candidate.owner
# Get input list of the candidate node with assert
# nodes removed
cand_inputs_assert_removed = []
for i in candidate.inputs:
if i.owner and isinstance(i.owner.op, CheckAndRaise):
cand_inputs_assert_removed.append(i.owner.inputs[0])
else:
cand_inputs_assert_removed.append(i)
# Get input list of the node with assert nodes removed
node_inputs_assert_removed = []
for i in node.inputs:
if i.owner and isinstance(i.owner.op, CheckAndRaise):
node_inputs_assert_removed.append(i.owner.inputs[0])
else:
node_inputs_assert_removed.append(i)
if var.owner and candidate_var.owner:
if merge_mode == "new_node":
inputs_match = True
else:
inputs_match = all(
node_in is cand_in
for node_in, cand_in in zip(
node_inputs_assert_removed, cand_inputs_assert_removed
var.owner.inputs, candidate_var.owner.inputs
)
)
......@@ -862,15 +718,10 @@ class MergeOptimizer(GlobalOptimizer):
clients = (
fgraph.clients[pairs[0][0]] + fgraph.clients[pairs[0][1]]
)
if (
sum(
[
i in flatten(c.op.destroy_map.values())
for c, i in clients
if c != "output" and c.op.destroy_map
]
)
> 1
if any(
i in flatten(c.op.destroy_map.values())
for c, i in clients
if c != "output" and c.op.destroy_map
):
continue
......@@ -884,10 +735,8 @@ class MergeOptimizer(GlobalOptimizer):
pairs = [(pairs[0][1], pairs[0][0])]
try:
# If all Constants, no need to call validate.
# Only need to check one of the var of each pairs.
# If it is a Constant, the other must also be a Constant as we merge them.
if all(isinstance(old, Constant) for old, new in pairs):
# If they're all `Constant`s, there's no need to call validate.
if all(isinstance(old, Constant) for old, _ in pairs):
fgraph.replace_all(pairs, reason="MergeOptimizer")
else:
fgraph.replace_all_validate(pairs, reason="MergeOptimizer")
......@@ -902,7 +751,6 @@ class MergeOptimizer(GlobalOptimizer):
nb_merged += len(pairs)
if isinstance(pairs[0][0], Constant):
nb_constant += 1
# print pairs, pairs[0][0].type
break
if fgraph.profile:
......@@ -920,8 +768,9 @@ class MergeOptimizer(GlobalOptimizer):
validate_time = None
callback_time = None
callbacks_time = {}
# clear blacklist
fgraph.merge_feature.blacklist = []
return (
nb_fail,
time.time() - t0,
......
......@@ -324,7 +324,7 @@ class TestMergeOptimizer:
@pytest.mark.skip(reason="This was disabled for some unknown reason")
def test_one_assert_merge(self):
# Merge two nodes, one has assert, the other not.
"""Merge two nodes, one has assert, the other not."""
x1 = matrix("x1")
x2 = matrix("x2")
e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
......@@ -342,8 +342,7 @@ class TestMergeOptimizer:
assert add_inputs[0] is add_inputs[1]
def test_both_assert_merge_identical(self):
# Merge two nodes, both have assert on the same node
# with the same conditions.
"""Merge two nodes, both have `Assert`s on the same node with the same conditions."""
x1 = matrix("x1")
x2 = matrix("x2")
e = dot(assert_op(x1, (x1 > x2).all()), x2) + dot(
......@@ -434,7 +433,7 @@ class TestMergeOptimizer:
assert add_inputs[0] is add_inputs[1]
def test_merge_noinput(self):
# Check that identical Apply nodes without inputs will be merged
"""Check that identical Apply nodes without inputs will be merged."""
x = NoInputOp(param=0)()
y = NoInputOp(param=0)()
z = NoInputOp(param=1)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论