提交 85478a92 authored 作者: Reyhane Askari's avatar Reyhane Askari

simplified the check for TypeError

上级 2ed066dc
...@@ -858,29 +858,25 @@ class MergeOptimizer(Optimizer): ...@@ -858,29 +858,25 @@ class MergeOptimizer(Optimizer):
hasattr(c.op, 'destroy_map')]) > 1: hasattr(c.op, 'destroy_map')]) > 1:
continue continue
if pairs[0][0].type != pairs[0][1].type: if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type:
res = pairs[0][0].type.convert_variable(pairs[0][1]) res = pairs[0][0].type.convert_variable(pairs[0][1])
if res is None or res.type != pairs[0][0].type: if res is None:
if (not isinstance(pairs[0][1], pairs[0][0].__class__) or num_broadcastable_dims_0 = sum(pairs[0][0].broadcastable)
pairs[0][0].dtype != pairs[0][1].dtype or num_broadcastable_dims_1 = sum(pairs[0][1].broadcastable)
pairs[0][0].ndim != pairs[0][1].ndim or # select the variable to be removed from the fgraph
pairs[0][0].broadcastable == pairs[0][1].broadcastable or len(pairs) != 1): if num_broadcastable_dims_0 <= num_broadcastable_dims_1:
raise TypeError selected_var_ind = 1
else: else:
num_broadcastable_dims_0 = sum(pairs[0][0].broadcastable) selected_var_ind = 0
num_broadcastable_dims_1 = sum(pairs[0][1].broadcastable) for i, j in zip(pairs[0][selected_var_ind].broadcastable,
# select the variable to be removed from the fgraph pairs[0][1 - selected_var_ind].broadcastable):
if num_broadcastable_dims_0 <= num_broadcastable_dims_1: if not i and j:
selected_var_ind = 1 raise TypeError
else: new_broadcast_pattern = theano.tensor.patternbroadcast(
selected_var_ind = 0 pairs[0][selected_var_ind],
for i, j in zip(pairs[0][selected_var_ind].broadcastable, pairs[0][1 - selected_var_ind].broadcastable)
pairs[0][1 - selected_var_ind].broadcastable): res = new_broadcast_pattern.type.convert_variable(pairs[0][1 - selected_var_ind])
if not i and j: if res:
raise TypeError
new_broadcast_pattern = theano.tensor.patternbroadcast(
pairs[0][selected_var_ind],
pairs[0][1 - selected_var_ind].broadcastable)
pairs = [(new_broadcast_pattern, pairs[0][1 - selected_var_ind])] pairs = [(new_broadcast_pattern, pairs[0][1 - selected_var_ind])]
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论