提交 486e2946 authored 作者: Caglar's avatar Caglar

some cosmetic changes.

上级 5bd4c316
...@@ -434,7 +434,7 @@ opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus') ...@@ -434,7 +434,7 @@ opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus')
opt.register_stabilize(log1p_neg_sigmoid, name='log1p_neg_sigmoid,') opt.register_stabilize(log1p_neg_sigmoid, name='log1p_neg_sigmoid,')
def is_1pexp(t): def is_1pexp(t, only_process_constants=True):
""" """
Returns Returns
...@@ -446,8 +446,9 @@ def is_1pexp(t): ...@@ -446,8 +446,9 @@ def is_1pexp(t):
""" """
if t.owner and t.owner.op == tensor.add: if t.owner and t.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = \
opt.scalarconsts_rest(t.owner.inputs) opt.scalarconsts_rest(t.owner.inputs,
# scalar_inputs are potentially dimshuffled and fill'd scalars only_process_constants=only_process_constants)
# scalar_inputs are potentially dimshuffled and filled with scalars
if len(nonconsts) == 1: if len(nonconsts) == 1:
maybe_exp = nonconsts[0] maybe_exp = nonconsts[0]
if maybe_exp.owner and maybe_exp.owner.op == tensor.exp: if maybe_exp.owner and maybe_exp.owner.op == tensor.exp:
...@@ -521,7 +522,7 @@ def is_mul(var): ...@@ -521,7 +522,7 @@ def is_mul(var):
return None return None
def partition_num_or_denom(r, f): def partition_num_or_denom(r, f, **kwargs):
if r.owner and r.owner.op == tensor.mul: if r.owner and r.owner.op == tensor.mul:
a = r.owner.inputs a = r.owner.inputs
else: else:
...@@ -532,7 +533,7 @@ def partition_num_or_denom(r, f): ...@@ -532,7 +533,7 @@ def partition_num_or_denom(r, f):
neg = False neg = False
rest = [] rest = []
for t in a: for t in a:
f_t = f(t) f_t = f(t, **kwargs)
if f_t is None: if f_t is None:
rest.append(t) rest.append(t)
else: else:
...@@ -956,7 +957,7 @@ def local_inv_1_plus_exp(node): ...@@ -956,7 +957,7 @@ def local_inv_1_plus_exp(node):
inv_arg = node.inputs[0] inv_arg = node.inputs[0]
if inv_arg.owner and inv_arg.owner.op == tensor.add: if inv_arg.owner and inv_arg.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = \
opt.scalarconsts_rest(inv_arg.owner.inputs) opt.scalarconsts_rest(inv_arg.owner.inputs, only_process_constants=True)
# scalar_inputs are potentially dimshuffled and fill'd scalars # scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1: if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp: if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论