提交 6e91c9d4 authored 作者: Frederic Bastien's avatar Frederic Bastien

white space fix.

上级 b69d2aae
...@@ -1147,18 +1147,18 @@ def apply_rebroadcast_opt(rval): ...@@ -1147,18 +1147,18 @@ def apply_rebroadcast_opt(rval):
changed = True changed = True
while changed and rval.owner: while changed and rval.owner:
changed = False changed = False
rval2 = theano.tensor.opt.local_useless_rebroadcast.transform(rval.owner) rval2 = theano.tensor.opt.local_useless_rebroadcast.transform(rval.owner)
if rval2:
assert len(rval2)==1
rval = rval2[0]
changed = True
if rval.owner:
rval2 = theano.tensor.opt.local_rebroadcast_lift.transform(rval.owner)
if rval2: if rval2:
assert len(rval2)==1 assert len(rval2)==1
rval = rval2[0] rval = rval2[0]
changed = True changed = True
if rval.owner:
rval2 = theano.tensor.opt.local_rebroadcast_lift.transform(rval.owner)
if rval2:
assert len(rval2)==1
rval = rval2[0]
changed = True
return rval return rval
...@@ -1216,7 +1216,7 @@ def local_mul_switch_sink(node): ...@@ -1216,7 +1216,7 @@ def local_mul_switch_sink(node):
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct return fct
except TypeError: except TypeError:
pass pass
try: try:
if get_constant_value(switch.inputs[2]) == 0.: if get_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:] listmul = node.inputs[:idx] + node.inputs[idx+1:]
...@@ -2398,9 +2398,9 @@ def local_log_add(node): ...@@ -2398,9 +2398,9 @@ def local_log_add(node):
def add_calculate(num, denum, aslist = False, out_type=None): def add_calculate(num, denum, aslist = False, out_type=None):
#TODO: make sure that this function and mul_calculate are similar #TODO: make sure that this function and mul_calculate are similar
if out_type is None: if out_type is None:
zero = 0.0 zero = 0.0
else: else:
zero = theano._asarray(0, dtype=out_type.dtype) zero = theano._asarray(0, dtype=out_type.dtype)
#zero = 0.0 if out_type is None else theano._asarray(0, dtype=out_type.dtype) #zero = 0.0 if out_type is None else theano._asarray(0, dtype=out_type.dtype)
v = reduce(N.add, num, zero) - reduce(N.add, denum, zero) v = reduce(N.add, num, zero) - reduce(N.add, denum, zero)
if aslist: if aslist:
...@@ -2856,7 +2856,7 @@ def local_grad_log_erfc_neg(node): ...@@ -2856,7 +2856,7 @@ def local_grad_log_erfc_neg(node):
#The constant is valid. Must check that the #The constant is valid. Must check that the
elif erfc_x is not x: elif erfc_x is not x:
return False return False
else: else:
return False return False
...@@ -3098,5 +3098,3 @@ if config.tensor.local_elemwise_fusion: ...@@ -3098,5 +3098,3 @@ if config.tensor.local_elemwise_fusion:
else: else:
_logger.debug("not enabling optimization fusion elemwise in fast_run") _logger.debug("not enabling optimization fusion elemwise in fast_run")
compile.optdb.register('elemwise_fusion', FusionOptimizer(local_elemwise_fusion), 71.00, 'fusion', 'local_elemwise_fusion') compile.optdb.register('elemwise_fusion', FusionOptimizer(local_elemwise_fusion), 71.00, 'fusion', 'local_elemwise_fusion')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论