提交 63a714ac authored 作者: Reyhane Askari's avatar Reyhane Askari

changed the order to avoid the error

上级 eb57fb1d
...@@ -858,20 +858,14 @@ class MergeOptimizer(Optimizer): ...@@ -858,20 +858,14 @@ class MergeOptimizer(Optimizer):
hasattr(c.op, 'destroy_map')]) > 1: hasattr(c.op, 'destroy_map')]) > 1:
continue continue
try: if pairs[0][0].type != pairs[0][1].type:
fgraph.replace_all_validate(pairs, 'MergeOptimizer') res = pairs[0][0].type.convert_variable(pairs[0][1])
except Exception as ex: if res is None or res.type != pairs[0][0].type:
if type(ex) is InconsistencyError:
success = False
nb_fail += 1
fgraph.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner))
elif type(ex) is TypeError:
if (not isinstance(pairs[0][1], pairs[0][0].__class__) or if (not isinstance(pairs[0][1], pairs[0][0].__class__) or
pairs[0][0].dtype != pairs[0][1].dtype or pairs[0][0].dtype != pairs[0][1].dtype or
pairs[0][0].ndim != pairs[0][1].ndim or pairs[0][0].ndim != pairs[0][1].ndim or
pairs[0][0].broadcastable == pairs[0][1].broadcastable or len(pairs) != 1): pairs[0][0].broadcastable == pairs[0][1].broadcastable or len(pairs) != 1):
raise raise TypeError
else: else:
num_broadcastable_dims_0 = sum(pairs[0][0].broadcastable) num_broadcastable_dims_0 = sum(pairs[0][0].broadcastable)
num_broadcastable_dims_1 = sum(pairs[0][1].broadcastable) num_broadcastable_dims_1 = sum(pairs[0][1].broadcastable)
...@@ -883,12 +877,20 @@ class MergeOptimizer(Optimizer): ...@@ -883,12 +877,20 @@ class MergeOptimizer(Optimizer):
for i, j in zip(pairs[0][selected_var_ind].broadcastable, for i, j in zip(pairs[0][selected_var_ind].broadcastable,
pairs[0][1 - selected_var_ind].broadcastable): pairs[0][1 - selected_var_ind].broadcastable):
if not i and j: if not i and j:
raise raise TypeError
new_broadcast_pattern = theano.tensor.patternbroadcast( new_broadcast_pattern = theano.tensor.patternbroadcast(
pairs[0][selected_var_ind], pairs[0][selected_var_ind],
pairs[0][1 - selected_var_ind].broadcastable) pairs[0][1 - selected_var_ind].broadcastable)
new_pairs = [(new_broadcast_pattern, pairs[0][1 - selected_var_ind])] pairs = [(new_broadcast_pattern, pairs[0][1 - selected_var_ind])]
fgraph.replace_all_validate(new_pairs, 'MergeOptimizer')
try:
fgraph.replace_all_validate(pairs, 'MergeOptimizer')
except Exception as ex:
if type(ex) is InconsistencyError:
success = False
nb_fail += 1
fgraph.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner))
if success: if success:
nb_merged += len(pairs) nb_merged += len(pairs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论