提交 e6d6f3fc authored 作者: Caglar's avatar Caglar

added scalar_constant_value func. arguments.

上级 486e2946
......@@ -126,7 +126,7 @@ def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)]
def scalarconsts_rest(inputs):
def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
"""Partition a list of variables into two kinds:
scalar constants, and the rest."""
consts = []
......@@ -134,7 +134,7 @@ def scalarconsts_rest(inputs):
nonconsts = []
for i in inputs:
try:
v = get_scalar_constant_value(i, only_process_constants=True)
v = get_scalar_constant_value(i, only_process_constants=only_process_constants)
consts.append(v)
origconsts.append(i)
except NotScalarConstantError:
......@@ -5788,7 +5788,7 @@ def local_log1p(node):
log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add:
scalars, scalar_inputs, nonconsts = scalarconsts_rest(
log_arg.owner.inputs)
log_arg.owner.inputs, only_process_constants=True)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and numpy.allclose(numpy.sum(scalars), 1):
if not nonconsts:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论