提交 543e0d11 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added log(sigmoid) -> softplus optimizations

上级 19487cd3
...@@ -348,7 +348,7 @@ class Constant(Value): ...@@ -348,7 +348,7 @@ class Constant(Value):
Value.__init__(self, type, data, name) Value.__init__(self, type, data, name)
def equals(self, other): def equals(self, other):
# this does what __eq__ should do, but Result and Apply should always be hashable by id # this does what __eq__ should do, but Result and Apply should always be hashable by id
return type(other) == type(self) and self.signature() == other.signature() return isinstance(other, Constant) and self.signature() == other.signature()
def signature(self): def signature(self):
return (self.type, self.data) return (self.type, self.data)
def __str__(self): def __str__(self):
......
...@@ -9,6 +9,7 @@ from ..printing import pprint ...@@ -9,6 +9,7 @@ from ..printing import pprint
import basic as tensor import basic as tensor
import elemwise import elemwise
import numpy import numpy
import opt
############ ############
# #
...@@ -733,3 +734,16 @@ class solve(gof.Op): ...@@ -733,3 +734,16 @@ class solve(gof.Op):
raise NotImplementedError() raise NotImplementedError()
logsigm_to_softplus = gof.PatternSub(
(tensor.log, (sigmoid, 'x')),
(tensor.neg, (softplus, (tensor.neg, 'x'))),
allow_multiple_clients = True)
log1msigm_to_softplus = gof.PatternSub(
(tensor.log, (tensor.sub, tensor.constant([[1.0]]), (sigmoid, 'x'))),
(tensor.neg, (softplus, 'x')),
allow_multiple_clients = True)
opt.register_specialize(logsigm_to_softplus, name = 'logsigm_to_softplus')
opt.register_specialize(log1msigm_to_softplus, name = 'log1msigm_to_softplus')
...@@ -98,10 +98,12 @@ compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run') ...@@ -98,10 +98,12 @@ compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run')
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
compile.optdb['canonicalize'].register((kwargs and kwargs.pop('name')) or lopt.__name__, lopt, 'fast_run', *tags) name = (kwargs and kwargs.pop('name')) or lopt.__name__
compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags)
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
compile.optdb['specialize'].register((kwargs and kwargs.pop('name')) or lopt.__name__, lopt, 'fast_run', *tags) name = (kwargs and kwargs.pop('name')) or lopt.__name__
compile.optdb['specialize'].register(name, lopt, 'fast_run', *tags)
###################### ######################
# DimShuffle lifters # # DimShuffle lifters #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论