提交 4caaa7ff authored 作者: Jeffrey Enos's avatar Jeffrey Enos 提交者: Thomas Wiecki

Preserve inplace operators in local_ultra_fast_sigmoid

上级 d0a9488a
......@@ -12,6 +12,7 @@ from aesara import scalar as aes
from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.printing import pprint
from aesara.scalar import sigmoid as scalar_sigmoid
from aesara.scalar.math import Sigmoid
from aesara.tensor.basic import constant
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import clip, sigmoid
......@@ -98,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter("ultra_fast_sigmoid")
# @opt.register_uncanonicalize
@local_optimizer([sigmoid])
@local_optimizer(None)
def local_ultra_fast_sigmoid(fgraph, node):
"""
When enabled, change all sigmoid to ultra_fast_sigmoid.
......@@ -112,8 +113,13 @@ def local_ultra_fast_sigmoid(fgraph, node):
to avoid interacting with them.
"""
if isinstance(node.op, Elemwise) and node.op.scalar_op == scalar_sigmoid:
out = ultra_fast_sigmoid(node.inputs[0])
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Sigmoid):
if node.op.inplace_pattern:
out = ultra_fast_sigmoid_inplace(node.inputs[0])
else:
out = ultra_fast_sigmoid(node.inputs[0])
copy_stack_trace(node.outputs[0], out)
def values_eq_approx_remove_low_prec(a, b):
......
......@@ -7,11 +7,13 @@ from aesara.configdefaults import config
from aesara.graph.opt import check_stack_trace
from aesara.scalar.basic import Composite
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.inplace import sigmoid_inplace
from aesara.tensor.math import clip, sigmoid
from aesara.tensor.nnet.sigm import (
hard_sigmoid,
ultra_fast_scalar_sigmoid,
ultra_fast_sigmoid,
ultra_fast_sigmoid_inplace,
)
from aesara.tensor.type import matrix
from tests.tensor.utils import (
......@@ -93,6 +95,13 @@ class TestSpecialSigmoidOpts:
assert topo[0].op == ultra_fast_sigmoid
assert len(topo) == 1
s = sigmoid_inplace(x)
f = aesara.function([x], s, mode=mode, accept_inplace=True)
assert check_stack_trace(f, ops_to_check=ultra_fast_sigmoid_inplace)
topo = f.maker.fgraph.toposort()
assert topo[0].op == ultra_fast_sigmoid_inplace
assert len(topo) == 1
@pytest.mark.skipif(config.cxx == "", reason="Needs a C compiler.")
def test_composite_c_code(self):
"""Make sure this `Op`'s `c_code` works within a `Composite`."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论