提交 4caaa7ff authored 作者: Jeffrey Enos's avatar Jeffrey Enos 提交者: Thomas Wiecki

Preserve inplace operators in local_ultra_fast_sigmoid

上级 d0a9488a
...@@ -12,6 +12,7 @@ from aesara import scalar as aes ...@@ -12,6 +12,7 @@ from aesara import scalar as aes
from aesara.graph.opt import copy_stack_trace, local_optimizer from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.printing import pprint from aesara.printing import pprint
from aesara.scalar import sigmoid as scalar_sigmoid from aesara.scalar import sigmoid as scalar_sigmoid
from aesara.scalar.math import Sigmoid
from aesara.tensor.basic import constant from aesara.tensor.basic import constant
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import clip, sigmoid from aesara.tensor.math import clip, sigmoid
...@@ -98,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter("ultra_fast_sigmoid") ...@@ -98,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter("ultra_fast_sigmoid")
# @opt.register_uncanonicalize # @opt.register_uncanonicalize
@local_optimizer([sigmoid]) @local_optimizer(None)
def local_ultra_fast_sigmoid(fgraph, node): def local_ultra_fast_sigmoid(fgraph, node):
""" """
When enabled, change all sigmoid to ultra_fast_sigmoid. When enabled, change all sigmoid to ultra_fast_sigmoid.
...@@ -112,8 +113,13 @@ def local_ultra_fast_sigmoid(fgraph, node): ...@@ -112,8 +113,13 @@ def local_ultra_fast_sigmoid(fgraph, node):
to avoid interacting with them. to avoid interacting with them.
""" """
if isinstance(node.op, Elemwise) and node.op.scalar_op == scalar_sigmoid:
out = ultra_fast_sigmoid(node.inputs[0]) if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Sigmoid):
if node.op.inplace_pattern:
out = ultra_fast_sigmoid_inplace(node.inputs[0])
else:
out = ultra_fast_sigmoid(node.inputs[0])
copy_stack_trace(node.outputs[0], out) copy_stack_trace(node.outputs[0], out)
def values_eq_approx_remove_low_prec(a, b): def values_eq_approx_remove_low_prec(a, b):
......
...@@ -7,11 +7,13 @@ from aesara.configdefaults import config ...@@ -7,11 +7,13 @@ from aesara.configdefaults import config
from aesara.graph.opt import check_stack_trace from aesara.graph.opt import check_stack_trace
from aesara.scalar.basic import Composite from aesara.scalar.basic import Composite
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.inplace import sigmoid_inplace
from aesara.tensor.math import clip, sigmoid from aesara.tensor.math import clip, sigmoid
from aesara.tensor.nnet.sigm import ( from aesara.tensor.nnet.sigm import (
hard_sigmoid, hard_sigmoid,
ultra_fast_scalar_sigmoid, ultra_fast_scalar_sigmoid,
ultra_fast_sigmoid, ultra_fast_sigmoid,
ultra_fast_sigmoid_inplace,
) )
from aesara.tensor.type import matrix from aesara.tensor.type import matrix
from tests.tensor.utils import ( from tests.tensor.utils import (
...@@ -93,6 +95,13 @@ class TestSpecialSigmoidOpts: ...@@ -93,6 +95,13 @@ class TestSpecialSigmoidOpts:
assert topo[0].op == ultra_fast_sigmoid assert topo[0].op == ultra_fast_sigmoid
assert len(topo) == 1 assert len(topo) == 1
s = sigmoid_inplace(x)
f = aesara.function([x], s, mode=mode, accept_inplace=True)
assert check_stack_trace(f, ops_to_check=ultra_fast_sigmoid_inplace)
topo = f.maker.fgraph.toposort()
assert topo[0].op == ultra_fast_sigmoid_inplace
assert len(topo) == 1
@pytest.mark.skipif(config.cxx == "", reason="Needs a C compiler.") @pytest.mark.skipif(config.cxx == "", reason="Needs a C compiler.")
def test_composite_c_code(self): def test_composite_c_code(self):
"""Make sure this `Op`'s `c_code` works within a `Composite`.""" """Make sure this `Op`'s `c_code` works within a `Composite`."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论