提交 106b357f authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for fgraph replacement order

上级 85478a92
...@@ -860,24 +860,12 @@ class MergeOptimizer(Optimizer): ...@@ -860,24 +860,12 @@ class MergeOptimizer(Optimizer):
if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type: if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type:
res = pairs[0][0].type.convert_variable(pairs[0][1]) res = pairs[0][0].type.convert_variable(pairs[0][1])
if res is None:
num_broadcastable_dims_0 = sum(pairs[0][0].broadcastable) # Since the fgraph.replace only checks the convert_variable
num_broadcastable_dims_1 = sum(pairs[0][1].broadcastable) # in one way, we change the order in the case that
# select the variable to be removed from the fgraph # convert_variable will not be successful.
if num_broadcastable_dims_0 <= num_broadcastable_dims_1: if not res:
selected_var_ind = 1 pairs = [(pairs[0][1], pairs[0][0])]
else:
selected_var_ind = 0
for i, j in zip(pairs[0][selected_var_ind].broadcastable,
pairs[0][1 - selected_var_ind].broadcastable):
if not i and j:
raise TypeError
new_broadcast_pattern = theano.tensor.patternbroadcast(
pairs[0][selected_var_ind],
pairs[0][1 - selected_var_ind].broadcastable)
res = new_broadcast_pattern.type.convert_variable(pairs[0][1 - selected_var_ind])
if res:
pairs = [(new_broadcast_pattern, pairs[0][1 - selected_var_ind])]
try: try:
fgraph.replace_all_validate(pairs, 'MergeOptimizer') fgraph.replace_all_validate(pairs, 'MergeOptimizer')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论