提交 01a24b83 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix scope issue in UltraFastScalarSigmoid.c_code

上级 7d07260b
......@@ -58,7 +58,7 @@ class UltraFastScalarSigmoid(aes.UnaryScalarOp):
dtype = node.outputs[0].type.dtype_specs()[1]
return (
"""
"""{
%(dtype)s x = 0.5 * %(x)s;
// The if is a tanh approximate.
if(x>=0) {
......@@ -74,10 +74,14 @@ class UltraFastScalarSigmoid(aes.UnaryScalarOp):
//%(z)s = 0.5*(ultrafasttanh(0.5*x)+1.);
%(z)s = 0.5*(%(z)s+1.);
"""
}"""
% locals()
)
@staticmethod
def c_code_cache_version():
return (5,)
ultra_fast_scalar_sigmoid = UltraFastScalarSigmoid(
aes.upgrade_to_float, name="ultra_fast_scalar_sigmoid"
......
import numpy as np
import pytest
import aesara
from aesara.compile.mode import get_default_mode, get_mode
from aesara.configdefaults import config
from aesara.graph.opt import check_stack_trace
from aesara.scalar.basic import Composite
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import clip, sigmoid
from aesara.tensor.nnet.sigm import hard_sigmoid, ultra_fast_sigmoid
from aesara.tensor.nnet.sigm import (
hard_sigmoid,
ultra_fast_scalar_sigmoid,
ultra_fast_sigmoid,
)
from aesara.tensor.type import matrix
from tests.tensor.utils import (
_good_broadcast_unary_normal_no_complex,
......@@ -59,9 +67,9 @@ class TestSpecialSigmoidOpts:
excluding = []
m = config.mode
if m == "FAST_COMPILE":
mode = aesara.compile.mode.get_mode("FAST_RUN")
mode = get_mode("FAST_RUN")
else:
mode = aesara.compile.mode.get_default_mode()
mode = get_default_mode()
if excluding:
return mode.excluding(*excluding)
else:
......@@ -84,7 +92,21 @@ class TestSpecialSigmoidOpts:
topo = f.maker.fgraph.toposort()
assert topo[0].op == ultra_fast_sigmoid
assert len(topo) == 1
f([[-50, -10, -4, -1, 0, 1, 4, 10, 50]])
@pytest.mark.skipif(config.cxx == "", reason="Needs a C compiler.")
def test_composite_c_code(self):
"""Make sure this `Op`'s `c_code` works within a `Composite`."""
x = matrix("x")
mode = get_mode("FAST_RUN").including("local_ultra_fast_sigmoid")
f = aesara.function([x], sigmoid(x) + sigmoid(x + 1), mode=mode)
topo = f.maker.fgraph.toposort()
assert isinstance(topo[0].op, Elemwise)
assert isinstance(topo[0].op.scalar_op, Composite)
assert ultra_fast_scalar_sigmoid in set(
node.op for node in topo[0].op.scalar_op.fgraph.toposort()
)
assert len(topo) == 1
def test_local_hard_sigmoid(self):
x = matrix("x")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论