提交 3e7f00e5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Disable unused Assert logic in MergeFeature

上级 b03bd80f
...@@ -21,7 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union ...@@ -21,7 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import aesara import aesara
from aesara.assert_op import Assert, assert_op from aesara.assert_op import Assert
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import destroyhandler as dh from aesara.graph import destroyhandler as dh
from aesara.graph.basic import ( from aesara.graph.basic import (
...@@ -563,24 +563,25 @@ class MergeFeature(Feature): ...@@ -563,24 +563,25 @@ class MergeFeature(Feature):
merge_candidates = [c for c, i in clients if c in self.nodes_seen] merge_candidates = [c for c, i in clients if c in self.nodes_seen]
# Put all clients of Assert inputs (if exist) into merge_candidates # Put all clients of `CheckAndRaise` inputs (if exist) into
# TODO: Deactivated for now as this cause cycle in the graph. # `merge_candidates`
# (There is a second deactivation part below.) # TODO: Deactivated for now, because it can create cycles in a
for i in []: # node.inputs: # graph. (There is a second deactivation part below.)
if i.owner and isinstance(i.owner.op, Assert): # for i in node.inputs:
node_has_assert = True # if i.owner and isinstance(i.owner.op, CheckAndRaise):
i_clients = fgraph.clients[i.owner.inputs[0]] # node_has_assert = True
assert_clients = [c for (c, _) in i_clients if c in self.nodes_seen] # i_clients = fgraph.clients[i.owner.inputs[0]]
# assert_clients = [c for (c, _) in i_clients if c in self.nodes_seen]
for idx in range(len(assert_clients)): #
client = assert_clients[idx] # for idx in range(len(assert_clients)):
if isinstance(i.owner.op, Assert): # client = assert_clients[idx]
o_clients = fgraph.clients[client.outputs[0]] # if isinstance(i.owner.op, CheckAndRaise):
for c in o_clients: # o_clients = fgraph.clients[client.outputs[0]]
if c[0] in self.nodes_seen: # for c in o_clients:
assert_clients.append(c[0]) # if c[0] in self.nodes_seen:
# assert_clients.append(c[0])
merge_candidates.extend(assert_clients) #
# merge_candidates.extend(assert_clients)
else: else:
# If two nodes have no input, but perform the same operation, # If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them. # they are not always constant-folded, so we want to merge them.
...@@ -598,28 +599,30 @@ class MergeFeature(Feature): ...@@ -598,28 +599,30 @@ class MergeFeature(Feature):
# Get input list of the candidate with assert removed # Get input list of the candidate with assert removed
cand_inputs_assert_removed = [] cand_inputs_assert_removed = []
# TODO: Deactivated while Assert merging is disabled. (See above and below.) # TODO: Deactivated while `CheckAndRaise` merging is disabled. (See
for i in []: # candidate.inputs: # above and below.)
if i.owner and isinstance(i.owner.op, Assert): # for i in candidate.inputs:
cand_has_assert = True # if i.owner and isinstance(i.owner.op, CheckAndRaise):
cand_inputs_assert_removed.append(i.owner.inputs[0]) # cand_has_assert = True
else: # cand_inputs_assert_removed.append(i.owner.inputs[0])
cand_inputs_assert_removed.append(i) # else:
# cand_inputs_assert_removed.append(i)
# TODO: Remove this when Assert merging is re-enabled. (See above.)
# Without Assert merging we can still look for identical Asserts, # TODO: Remove this when `CheckAndRaise` merging is
# so we should not treat Asserts separately for now. # re-enabled. (See above.) Without `CheckAndRaise` merging we can
# still look for identical `CheckAndRaise`, so we should not treat
# `CheckAndRaise`s separately for now.
cand_inputs_assert_removed = candidate.inputs cand_inputs_assert_removed = candidate.inputs
# Get input list of the node with assert removed # Get input list of the node with assert removed
if node_has_assert: # if node_has_assert:
node_inputs_assert_removed = [] # node_inputs_assert_removed = []
for i in node.inputs: # for i in node.inputs:
if i.owner and isinstance(i.owner.op, Assert): # if i.owner and isinstance(i.owner.op, CheckAndRaise):
node_inputs_assert_removed.append(i.owner.inputs[0]) # node_inputs_assert_removed.append(i.owner.inputs[0])
else: # else:
node_inputs_assert_removed.append(i) # node_inputs_assert_removed.append(i)
else: # else:
node_inputs_assert_removed = node.inputs node_inputs_assert_removed = node.inputs
inputs_match = all( inputs_match = all(
...@@ -635,7 +638,7 @@ class MergeFeature(Feature): ...@@ -635,7 +638,7 @@ class MergeFeature(Feature):
continue continue
# replace node with candidate # replace node with candidate
if not (node_has_assert or cand_has_assert): if not node_has_assert and not cand_has_assert:
# Schedule transfer of clients from node to candidate # Schedule transfer of clients from node to candidate
pairs = list( pairs = list(
zip( zip(
...@@ -645,32 +648,32 @@ class MergeFeature(Feature): ...@@ -645,32 +648,32 @@ class MergeFeature(Feature):
) )
) )
# if the current node has assert input, it should not be # # if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input # # replaced with a candidate node which has no assert input
elif node_has_assert and not cand_has_assert: # elif node_has_assert and not cand_has_assert:
pairs = list( # pairs = list(
zip( # zip(
candidate.outputs, # candidate.outputs,
node.outputs, # node.outputs,
["merge"] * len(node.outputs), # ["merge"] * len(node.outputs),
) # )
) # )
else: # else:
new_inputs = self.get_merged_assert_input(node, candidate) # new_inputs = self.get_merged_assert_input(node, candidate)
new_node = node.op(*new_inputs) # new_node = node.op(*new_inputs)
pairs = list( # pairs = list(
zip( # zip(
node.outputs, # node.outputs,
new_node.owner.outputs, # new_node.owner.outputs,
["new_node"] * len(node.outputs), # ["new_node"] * len(node.outputs),
) # )
) + list( # ) + list(
zip( # zip(
candidate.outputs, # candidate.outputs,
new_node.owner.outputs, # new_node.owner.outputs,
["new_node"] * len(node.outputs), # ["new_node"] * len(node.outputs),
) # )
) # )
# transfer names # transfer names
for pair in pairs: for pair in pairs:
...@@ -689,29 +692,35 @@ class MergeFeature(Feature): ...@@ -689,29 +692,35 @@ class MergeFeature(Feature):
if not node.inputs: if not node.inputs:
self.noinput_nodes.add(node) self.noinput_nodes.add(node)
def get_merged_assert_input(self, node, candidate): # def get_merged_assert_input(self, node, candidate):
new_inputs = [] # new_inputs = []
for node_i, cand_i in zip(node.inputs, candidate.inputs): # for node_i, cand_i in zip(node.inputs, candidate.inputs):
# if node_i is assert # if node_i.owner and isinstance(node_i.owner.op, CheckAndRaise):
if node_i.owner and isinstance(node_i.owner.op, Assert): # if (
# node_i is assert, cand_i is assert # cand_i.owner
if cand_i.owner and isinstance(cand_i.owner.op, Assert): # and isinstance(cand_i.owner.op, CheckAndRaise)
# Here two assert nodes are merged. # and node_i.owner.op.exc_type == cand_i.owner.op.exc_type
# Step 1. Merge conditions of both assert nodes. # ):
# Step 2. Make the new assert node # # Here two assert nodes are merged.
node_cond = node_i.owner.inputs[1:] # # Step 1. Merge conditions of both assert nodes.
cand_cond = cand_i.owner.inputs[1:] # # Step 2. Make the new assert node
new_cond = list(set(node_cond + cand_cond)) # node_cond = node_i.owner.inputs[1:]
new_inputs.append(assert_op(node_i.owner.inputs[0], *new_cond)) # cand_cond = cand_i.owner.inputs[1:]
# new_cond = list(set(node_cond + cand_cond))
# node_i is assert, cand_i is not assert # new_raise_op = CheckAndRaise(
else: # node_i.owner.op.exc_type,
new_inputs.append(node_i) # "; ".join([node_i.owner.op.msg, cand_i.owner.op.msg]),
else: # )
# if node_i is not an assert node, append cand_i # new_inputs.append(new_raise_op(*(node_i.owner.inputs[:1] + new_cond)))
new_inputs.append(cand_i) #
# # node_i is assert, cand_i is not assert
return new_inputs # else:
# new_inputs.append(node_i)
# else:
# # if node_i is not an assert node, append cand_i
# new_inputs.append(cand_i)
#
# return new_inputs
class MergeOptimizer(GlobalOptimizer): class MergeOptimizer(GlobalOptimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论