提交 374b503f authored 作者: James Bergstra's avatar James Bergstra

Modified log1msigm_to_softplus to work for tensor graphs not just scalar

graphs.
上级 1b91bde8
...@@ -100,8 +100,20 @@ logsigm_to_softplus = gof.PatternSub( ...@@ -100,8 +100,20 @@ logsigm_to_softplus = gof.PatternSub(
(tensor.neg, (softplus, (tensor.neg, 'x'))), (tensor.neg, (softplus, (tensor.neg, 'x'))),
allow_multiple_clients = True) allow_multiple_clients = True)
def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try:
v = opt.get_constant_value(expr)
return numpy.allclose(v, 1)
except TypeError:
return False
log1msigm_to_softplus = gof.PatternSub( log1msigm_to_softplus = gof.PatternSub(
(tensor.log, (tensor.sub, tensor.constant([[1.0]]), (sigmoid, 'x'))), (tensor.log,
(tensor.sub,
dict(pattern='y', constraint = _is_1),
(sigmoid, 'x'))),
(tensor.neg, (softplus, 'x')), (tensor.neg, (softplus, 'x')),
allow_multiple_clients = True) allow_multiple_clients = True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论