提交 01a24b83 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix scope issue in UltraFastScalarSigmoid.c_code

上级 7d07260b
...@@ -58,7 +58,7 @@ class UltraFastScalarSigmoid(aes.UnaryScalarOp): ...@@ -58,7 +58,7 @@ class UltraFastScalarSigmoid(aes.UnaryScalarOp):
dtype = node.outputs[0].type.dtype_specs()[1] dtype = node.outputs[0].type.dtype_specs()[1]
return ( return (
""" """{
%(dtype)s x = 0.5 * %(x)s; %(dtype)s x = 0.5 * %(x)s;
// The if is a tanh approximate. // The if is a tanh approximate.
if(x>=0) { if(x>=0) {
...@@ -74,10 +74,14 @@ class UltraFastScalarSigmoid(aes.UnaryScalarOp): ...@@ -74,10 +74,14 @@ class UltraFastScalarSigmoid(aes.UnaryScalarOp):
//%(z)s = 0.5*(ultrafasttanh(0.5*x)+1.); //%(z)s = 0.5*(ultrafasttanh(0.5*x)+1.);
%(z)s = 0.5*(%(z)s+1.); %(z)s = 0.5*(%(z)s+1.);
""" }"""
% locals() % locals()
) )
@staticmethod
def c_code_cache_version():
return (5,)
ultra_fast_scalar_sigmoid = UltraFastScalarSigmoid( ultra_fast_scalar_sigmoid = UltraFastScalarSigmoid(
aes.upgrade_to_float, name="ultra_fast_scalar_sigmoid" aes.upgrade_to_float, name="ultra_fast_scalar_sigmoid"
......
import numpy as np import numpy as np
import pytest
import aesara import aesara
from aesara.compile.mode import get_default_mode, get_mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.opt import check_stack_trace from aesara.graph.opt import check_stack_trace
from aesara.scalar.basic import Composite
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import clip, sigmoid from aesara.tensor.math import clip, sigmoid
from aesara.tensor.nnet.sigm import hard_sigmoid, ultra_fast_sigmoid from aesara.tensor.nnet.sigm import (
hard_sigmoid,
ultra_fast_scalar_sigmoid,
ultra_fast_sigmoid,
)
from aesara.tensor.type import matrix from aesara.tensor.type import matrix
from tests.tensor.utils import ( from tests.tensor.utils import (
_good_broadcast_unary_normal_no_complex, _good_broadcast_unary_normal_no_complex,
...@@ -59,9 +67,9 @@ class TestSpecialSigmoidOpts: ...@@ -59,9 +67,9 @@ class TestSpecialSigmoidOpts:
excluding = [] excluding = []
m = config.mode m = config.mode
if m == "FAST_COMPILE": if m == "FAST_COMPILE":
mode = aesara.compile.mode.get_mode("FAST_RUN") mode = get_mode("FAST_RUN")
else: else:
mode = aesara.compile.mode.get_default_mode() mode = get_default_mode()
if excluding: if excluding:
return mode.excluding(*excluding) return mode.excluding(*excluding)
else: else:
...@@ -84,7 +92,21 @@ class TestSpecialSigmoidOpts: ...@@ -84,7 +92,21 @@ class TestSpecialSigmoidOpts:
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert topo[0].op == ultra_fast_sigmoid assert topo[0].op == ultra_fast_sigmoid
assert len(topo) == 1 assert len(topo) == 1
f([[-50, -10, -4, -1, 0, 1, 4, 10, 50]])
@pytest.mark.skipif(config.cxx == "", reason="Needs a C compiler.")
def test_composite_c_code(self):
"""Make sure this `Op`'s `c_code` works within a `Composite`."""
x = matrix("x")
mode = get_mode("FAST_RUN").including("local_ultra_fast_sigmoid")
f = aesara.function([x], sigmoid(x) + sigmoid(x + 1), mode=mode)
topo = f.maker.fgraph.toposort()
assert isinstance(topo[0].op, Elemwise)
assert isinstance(topo[0].op.scalar_op, Composite)
assert ultra_fast_scalar_sigmoid in set(
node.op for node in topo[0].op.scalar_op.fgraph.toposort()
)
assert len(topo) == 1
def test_local_hard_sigmoid(self): def test_local_hard_sigmoid(self):
x = matrix("x") x = matrix("x")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论