提交 0d658a6d authored 作者: Ricardo's avatar Ricardo 提交者: Thomas Wiecki

Move Sigmoid components to respective modules

上级 da04eff0
...@@ -1515,7 +1515,7 @@ class ProfileStats: ...@@ -1515,7 +1515,7 @@ class ProfileStats:
from aesara import scalar as aes from aesara import scalar as aes
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import Dot from aesara.tensor.math import Dot
from aesara.tensor.nnet.sigm import ScalarSigmoid, ScalarSoftplus from aesara.tensor.nnet.sigm import ScalarSoftplus
from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.op import RandomVariable
scalar_op_amdlibm_no_speed_up = [ scalar_op_amdlibm_no_speed_up = [
...@@ -1567,7 +1567,7 @@ class ProfileStats: ...@@ -1567,7 +1567,7 @@ class ProfileStats:
aes.Tanh, aes.Tanh,
aes.Cosh, aes.Cosh,
aes.Sinh, aes.Sinh,
ScalarSigmoid, aes.Sigmoid,
ScalarSoftplus, ScalarSoftplus,
] ]
......
...@@ -16,6 +16,7 @@ from aesara.scalar.basic import ( ...@@ -16,6 +16,7 @@ from aesara.scalar.basic import (
complex_types, complex_types,
discrete_types, discrete_types,
exp, exp,
float64,
float_types, float_types,
upcast, upcast,
upgrade_to_float, upgrade_to_float,
...@@ -928,3 +929,49 @@ class I0(UnaryScalarOp): ...@@ -928,3 +929,49 @@ class I0(UnaryScalarOp):
i0 = I0(upgrade_to_float, name="i0") i0 = I0(upgrade_to_float, name="i0")
class Sigmoid(UnaryScalarOp):
    """
    Logistic sigmoid function (1 / (1 + exp(-x))), also known as expit or inverse logit.
    """

    # NumPy-function spec: dispatch to SciPy's vectorized expit when available.
    nfunc_spec = ("scipy.special.expit", 1, 1)

    def impl(self, x):
        """Python-side implementation; prefers SciPy's `expit` when importable."""
        if imported_scipy_special:
            return scipy.special.expit(x)
        else:
            # BUG FIX: the original fell through and implicitly returned None
            # when SciPy was unavailable; the base implementation's result
            # must be propagated to the caller.
            return super().impl(x)

    def grad(self, inp, grads):
        """Symbolic gradient: d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))."""
        (x,) = inp
        (gz,) = grads
        y = sigmoid(x)
        rval = gz * y * (1.0 - y)
        # The gradient is only defined for floating-point outputs.
        assert rval.type.dtype.find("float") != -1
        return [rval]

    def c_code(self, node, name, inp, out, sub):
        """Emit C code: double literals for float64, float literals otherwise."""
        (x,) = inp
        (z,) = out
        if node.inputs[0].type in float_types:
            if node.inputs[0].type == float64:
                return f"""{z} = 1.0 / (1.0 + exp(-{x}));"""
            else:
                return f"""{z} = 1.0f / (1.0f + exp(-{x}));"""
        else:
            raise NotImplementedError("only floatingpoint is implemented")

    def c_code_cache_version(self):
        # Bump the first element whenever the generated C code changes.
        v = super().c_code_cache_version()
        if v:
            return (2,) + v
        else:
            return v


sigmoid = Sigmoid(upgrade_to_float, name="sigmoid")
...@@ -33,6 +33,7 @@ from aesara.tensor.math import ( ...@@ -33,6 +33,7 @@ from aesara.tensor.math import (
rad2deg, rad2deg,
round_half_to_even, round_half_to_even,
sgn, sgn,
sigmoid,
sin, sin,
sinh, sinh,
sqr, sqr,
...@@ -3189,7 +3190,7 @@ def structured_monoid(tensor_op): ...@@ -3189,7 +3190,7 @@ def structured_monoid(tensor_op):
return decorator return decorator
@structured_monoid(aesara.tensor.nnet.sigmoid) @structured_monoid(sigmoid)
def structured_sigmoid(x): def structured_sigmoid(x):
""" """
Structured elemwise sigmoid. Structured elemwise sigmoid.
......
...@@ -308,6 +308,11 @@ def iv_inplace(v, x): ...@@ -308,6 +308,11 @@ def iv_inplace(v, x):
"""Modified Bessel function of the first kind of order v (real).""" """Modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def sigmoid_inplace(x):
    """Logistic sigmoid function (1 / (1 + exp(-x))), also known as expit or inverse logit."""
@scalar_elemwise @scalar_elemwise
def second_inplace(a): def second_inplace(a):
"""Fill `a` with `b`""" """Fill `a` with `b`"""
......
...@@ -1402,6 +1402,11 @@ def iv(v, x): ...@@ -1402,6 +1402,11 @@ def iv(v, x):
"""Modified Bessel function of the first kind of order v (real).""" """Modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def sigmoid(x):
    """Logistic sigmoid function (1 / (1 + exp(-x))), also known as expit or inverse logit."""
@scalar_elemwise @scalar_elemwise
def real(z): def real(z):
"""Return real component of complex-valued tensor `z`""" """Return real component of complex-valued tensor `z`"""
...@@ -2828,6 +2833,7 @@ __all__ = [ ...@@ -2828,6 +2833,7 @@ __all__ = [
"i0", "i0",
"i1", "i1",
"iv", "iv",
"sigmoid",
"real", "real",
"imag", "imag",
"angle", "angle",
......
...@@ -43,11 +43,4 @@ from aesara.tensor.nnet.basic import ( ...@@ -43,11 +43,4 @@ from aesara.tensor.nnet.basic import (
softsign, softsign,
) )
from aesara.tensor.nnet.batchnorm import batch_normalization from aesara.tensor.nnet.batchnorm import batch_normalization
from aesara.tensor.nnet.sigm import ( from aesara.tensor.nnet.sigm import hard_sigmoid, softplus, ultra_fast_sigmoid
hard_sigmoid,
scalar_sigmoid,
sigmoid,
sigmoid_inplace,
softplus,
ultra_fast_sigmoid,
)
...@@ -53,11 +53,12 @@ from aesara.tensor.math import ( ...@@ -53,11 +53,12 @@ from aesara.tensor.math import (
mul, mul,
neg, neg,
or_, or_,
sigmoid,
) )
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tanh, tensordot, true_div from aesara.tensor.math import tanh, tensordot, true_div
from aesara.tensor.nnet.blocksparse import sparse_block_dot from aesara.tensor.nnet.blocksparse import sparse_block_dot
from aesara.tensor.nnet.sigm import sigmoid, softplus from aesara.tensor.nnet.sigm import softplus
from aesara.tensor.shape import shape, shape_padleft from aesara.tensor.shape import shape, shape_padleft
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
......
...@@ -16,79 +16,25 @@ from aesara import scalar as aes ...@@ -16,79 +16,25 @@ from aesara import scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.opt import PatternSub, copy_stack_trace, local_optimizer from aesara.graph.opt import PatternSub, copy_stack_trace, local_optimizer
from aesara.printing import pprint from aesara.printing import pprint
from aesara.scalar import sigmoid as scalar_sigmoid
from aesara.tensor import basic_opt from aesara.tensor import basic_opt
from aesara.tensor.basic import constant, get_scalar_constant_value from aesara.tensor.basic import constant, get_scalar_constant_value
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import add, clip, exp, inv, log, log1p, mul, neg, sub, true_div from aesara.tensor.math import (
from aesara.tensor.type import TensorType, values_eq_approx_remove_inf add,
clip,
exp,
imported_scipy_special = False inv,
try: log,
import scipy.special log1p,
import scipy.stats mul,
neg,
imported_scipy_special = True sigmoid,
# Importing scipy.special may raise ValueError. sub,
# See http://projects.scipy.org/scipy/ticket/1739 true_div,
except (ImportError, ValueError):
pass
class ScalarSigmoid(aes.UnaryScalarOp):
    """
    Logistic sigmoid function (1 / (1 + exp(-x))), also known as expit or inverse logit.
    """

    # NumPy-function spec: dispatch to SciPy's vectorized expit when available.
    nfunc_spec = ("scipy.special.expit", 1, 1)

    def impl(self, x):
        """Python-side implementation; prefers SciPy's `expit` when importable."""
        if imported_scipy_special:
            return scipy.special.expit(x)
        else:
            # BUG FIX: the original fell through and implicitly returned None
            # when SciPy was unavailable; the base implementation's result
            # must be propagated to the caller.
            return super().impl(x)

    def grad(self, inp, grads):
        """Symbolic gradient: d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))."""
        (x,) = inp
        (gz,) = grads
        # `scalar_sigmoid` is the module-level instance defined after this
        # class; the name is resolved at call time, so the forward reference
        # is safe.
        y = scalar_sigmoid(x)
        rval = gz * y * (1.0 - y)
        # The gradient is only defined for floating-point outputs.
        assert rval.type.dtype.find("float") != -1
        return [rval]

    def c_code(self, node, name, inp, out, sub):
        """Emit C code: float literals for float16/float32, double for float64."""
        (x,) = inp
        (z,) = out
        if node.inputs[0].type == aes.float32 or node.inputs[0].type == aes.float16:
            return f"""{z} = 1.0f / (1.0f + exp(-{x}));"""
        elif node.inputs[0].type == aes.float64:
            return f"""{z} = 1.0 / (1.0 + exp(-{x}));"""
        else:
            raise NotImplementedError("only floatingpoint is implemented")

    def c_code_cache_version(self):
        # Bump the first element whenever the generated C code changes.
        v = super().c_code_cache_version()
        if v:
            return (2,) + v
        else:
            return v
# Module-level scalar Op instance and its elemwise (tensor-level) wrapper.
scalar_sigmoid = ScalarSigmoid(aes.upgrade_to_float, name="scalar_sigmoid")
sigmoid = Elemwise(scalar_sigmoid, name="sigmoid")
sigmoid_inplace = Elemwise(
ScalarSigmoid(aes.transfer_type(0)),
inplace_pattern={0: 0},
name="sigmoid_inplace",
) )
from aesara.tensor.type import TensorType, values_eq_approx_remove_inf
pprint.assign(sigmoid, printing.FunctionPrinter("sigmoid"))
class UltraFastScalarSigmoid(aes.UnaryScalarOp): class UltraFastScalarSigmoid(aes.UnaryScalarOp):
...@@ -258,7 +204,6 @@ class ScalarSoftplus(aes.UnaryScalarOp): ...@@ -258,7 +204,6 @@ class ScalarSoftplus(aes.UnaryScalarOp):
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package" "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
""" """
@staticmethod
def static_impl(x): def static_impl(x):
# If x is an int8 or uint8, numpy.exp will compute the result in # If x is an int8 or uint8, numpy.exp will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
......
...@@ -12,8 +12,8 @@ from aesara.gradient import DisconnectedType, Rop, grad ...@@ -12,8 +12,8 @@ from aesara.gradient import DisconnectedType, Rop, grad
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.tensor.math import dot, exp from aesara.tensor.math import dot, exp
from aesara.tensor.math import round as aet_round from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sigmoid
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.nnet import sigmoid
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
from tests import unittest_tools from tests import unittest_tools
......
...@@ -3,9 +3,8 @@ import numpy as np ...@@ -3,9 +3,8 @@ import numpy as np
from aesara.compile.function.pfunc import pfunc from aesara.compile.function.pfunc import pfunc
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
from aesara.gradient import grad from aesara.gradient import grad
from aesara.tensor.math import dot from aesara.tensor.math import dot, sigmoid
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.nnet import sigmoid
from aesara.tensor.type import dvector from aesara.tensor.type import dvector
......
...@@ -19,7 +19,7 @@ class Mlp: ...@@ -19,7 +19,7 @@ class Mlp:
x = dmatrix("x") x = dmatrix("x")
wh = aesara.shared(self.rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True) wh = aesara.shared(self.rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)
bh = aesara.shared(np.zeros(nhiddens), borrow=True) bh = aesara.shared(np.zeros(nhiddens), borrow=True)
h = aesara.tensor.nnet.sigmoid(aet.dot(x, wh) + bh) h = aesara.tensor.sigmoid(aet.dot(x, wh) + bh)
wy = aesara.shared(self.rng.normal(0, 1, (nhiddens, noutputs))) wy = aesara.shared(self.rng.normal(0, 1, (nhiddens, noutputs)))
by = aesara.shared(np.zeros(noutputs), borrow=True) by = aesara.shared(np.zeros(noutputs), borrow=True)
...@@ -45,7 +45,7 @@ class OfgNested: ...@@ -45,7 +45,7 @@ class OfgNested:
class Ofg: class Ofg:
def __init__(self): def __init__(self):
x, y, z = scalars("xyz") x, y, z = scalars("xyz")
e = aesara.tensor.nnet.sigmoid((x + y + z) ** 2) e = aesara.tensor.sigmoid((x + y + z) ** 2)
op = aesara.compile.builders.OpFromGraph([x, y, z], [e]) op = aesara.compile.builders.OpFromGraph([x, y, z], [e])
e2 = op(x, y, z) + op(z, y, x) e2 = op(x, y, z) + op(z, y, x)
...@@ -56,7 +56,7 @@ class Ofg: ...@@ -56,7 +56,7 @@ class Ofg:
class OfgSimple: class OfgSimple:
def __init__(self): def __init__(self):
x, y, z = scalars("xyz") x, y, z = scalars("xyz")
e = aesara.tensor.nnet.sigmoid((x + y + z) ** 2) e = aesara.tensor.sigmoid((x + y + z) ** 2)
op = aesara.compile.builders.OpFromGraph([x, y, z], [e]) op = aesara.compile.builders.OpFromGraph([x, y, z], [e])
e2 = op(x, y, z) e2 = op(x, y, z)
......
import numpy as np import numpy as np
import aesara import aesara
from aesara.tensor import nnet from aesara.tensor.math import dot, sigmoid, tanh
from aesara.tensor.math import dot, tanh
class Model: class Model:
...@@ -126,10 +125,10 @@ class GRU(Layer): ...@@ -126,10 +125,10 @@ class GRU(Layer):
"""step through processed input to create output""" """step through processed input to create output"""
def step(inp, s_prev): def step(inp, s_prev):
i_t = nnet.sigmoid( i_t = sigmoid(
dot(inp, self.W_i) + dot(s_prev, self.R_i) + self.b_wi + self.b_ru dot(inp, self.W_i) + dot(s_prev, self.R_i) + self.b_wi + self.b_ru
) )
r_t = nnet.sigmoid( r_t = sigmoid(
dot(inp, self.W_r) + dot(s_prev, self.R_r) + self.b_wr + self.b_rr dot(inp, self.W_r) + dot(s_prev, self.R_r) + self.b_wr + self.b_rr
) )
...@@ -230,13 +229,13 @@ class LSTM(Layer): ...@@ -230,13 +229,13 @@ class LSTM(Layer):
"""step through processed input to create output""" """step through processed input to create output"""
def step(x_t, h_tm1, c_tm1): def step(x_t, h_tm1, c_tm1):
i_t = nnet.sigmoid( i_t = sigmoid(
dot(x_t, self.W_i) + dot(h_tm1, self.R_i) + self.b_wi + self.b_ri dot(x_t, self.W_i) + dot(h_tm1, self.R_i) + self.b_wi + self.b_ri
) )
f_t = nnet.sigmoid( f_t = sigmoid(
dot(x_t, self.W_f) + dot(h_tm1, self.R_f) + self.b_wf + self.b_rf dot(x_t, self.W_f) + dot(h_tm1, self.R_f) + self.b_wf + self.b_rf
) )
o_t = nnet.sigmoid( o_t = sigmoid(
dot(x_t, self.W_o) + dot(h_tm1, self.R_o) + self.b_ro + self.b_wo dot(x_t, self.W_o) + dot(h_tm1, self.R_o) + self.b_ro + self.b_wo
) )
......
...@@ -31,7 +31,7 @@ from aesara.tensor.math import MaxAndArgmax ...@@ -31,7 +31,7 @@ from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
from aesara.tensor.math import clip, cosh, gammaln, log from aesara.tensor.math import clip, cosh, gammaln, log
from aesara.tensor.math import max as aet_max from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, prod from aesara.tensor.math import maximum, prod, sigmoid
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.basic import RandomVariable, normal from aesara.tensor.random.basic import RandomVariable, normal
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
...@@ -944,7 +944,7 @@ def test_nnet(): ...@@ -944,7 +944,7 @@ def test_nnet():
x = vector("x") x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
out = aet_nnet.sigmoid(x) out = sigmoid(x)
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
...@@ -269,7 +269,7 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -269,7 +269,7 @@ def test_create_numba_signature(v, expected, force_scalar):
( (
[aet.vector()], [aet.vector()],
[np.random.randn(100).astype(config.floatX)], [np.random.randn(100).astype(config.floatX)],
lambda x: aet.nnet.sigmoid(x), lambda x: aet.sigmoid(x),
), ),
( (
[aet.vector() for i in range(4)], [aet.vector() for i in range(4)],
......
...@@ -52,10 +52,10 @@ from aesara.tensor.blas import Dot22 ...@@ -52,10 +52,10 @@ from aesara.tensor.blas import Dot22
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import Dot from aesara.tensor.math import Dot
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
from aesara.tensor.math import dot, mean from aesara.tensor.math import dot, mean, sigmoid
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tanh from aesara.tensor.math import tanh
from aesara.tensor.nnet import categorical_crossentropy, sigmoid, softmax_graph from aesara.tensor.nnet import categorical_crossentropy, softmax_graph
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape_i, reshape, shape, specify_shape from aesara.tensor.shape import Shape_i, reshape, shape, specify_shape
from aesara.tensor.sharedvar import SharedVariable from aesara.tensor.sharedvar import SharedVariable
......
...@@ -5,9 +5,8 @@ import aesara.tensor.basic as aet ...@@ -5,9 +5,8 @@ import aesara.tensor.basic as aet
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import Rop, grad, jacobian from aesara.gradient import Rop, grad, jacobian
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.tensor import nnet
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import Dot, dot from aesara.tensor.math import Dot, dot, sigmoid
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tanh from aesara.tensor.math import tanh
from aesara.tensor.type import matrix, tensor3, vector from aesara.tensor.type import matrix, tensor3, vector
...@@ -322,8 +321,8 @@ class TestPushOutSumOfDot: ...@@ -322,8 +321,8 @@ class TestPushOutSumOfDot:
): ):
pre_r = ri + h.dot(U) pre_r = ri + h.dot(U)
pre_z = zi + h.dot(V) pre_z = zi + h.dot(V)
r = nnet.sigmoid(pre_r) r = sigmoid(pre_r)
z = nnet.sigmoid(pre_z) z = sigmoid(pre_z)
after_r = r * h after_r = r * h
pre_h = x + after_r.dot(W) pre_h = x + after_r.dot(W)
......
...@@ -9,7 +9,17 @@ from aesara.gradient import grad ...@@ -9,7 +9,17 @@ from aesara.gradient import grad
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import check_stack_trace from aesara.graph.opt import check_stack_trace
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import Argmax, add, argmax, dot, exp, log, max_and_argmax, mean from aesara.tensor.math import (
Argmax,
add,
argmax,
dot,
exp,
log,
max_and_argmax,
mean,
sigmoid,
)
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tanh, true_div from aesara.tensor.math import tanh, true_div
from aesara.tensor.nnet.basic import ( from aesara.tensor.nnet.basic import (
...@@ -45,7 +55,6 @@ from aesara.tensor.nnet.basic import ( ...@@ -45,7 +55,6 @@ from aesara.tensor.nnet.basic import (
softmax_with_bias, softmax_with_bias,
softsign, softsign,
) )
from aesara.tensor.nnet.sigm import sigmoid
from aesara.tensor.shape import shape_padleft, specify_shape from aesara.tensor.shape import shape_padleft, specify_shape
from aesara.tensor.subtensor import AdvancedSubtensor from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
......
...@@ -5,7 +5,8 @@ import aesara.tensor as aet ...@@ -5,7 +5,8 @@ import aesara.tensor as aet
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.opt import check_stack_trace from aesara.graph.opt import check_stack_trace
from aesara.graph.toolbox import is_same_graph from aesara.graph.toolbox import is_same_graph
from aesara.tensor.inplace import neg_inplace from aesara.tensor import sigmoid
from aesara.tensor.inplace import neg_inplace, sigmoid_inplace
from aesara.tensor.math import clip, exp, log, mul, neg from aesara.tensor.math import clip, exp, log, mul, neg
from aesara.tensor.nnet.sigm import ( from aesara.tensor.nnet.sigm import (
ScalarSoftplus, ScalarSoftplus,
...@@ -15,8 +16,6 @@ from aesara.tensor.nnet.sigm import ( ...@@ -15,8 +16,6 @@ from aesara.tensor.nnet.sigm import (
parse_mul_tree, parse_mul_tree,
perform_sigm_times_exp, perform_sigm_times_exp,
register_local_1msigmoid, register_local_1msigmoid,
sigmoid,
sigmoid_inplace,
simplify_mul, simplify_mul,
softplus, softplus,
ultra_fast_sigmoid, ultra_fast_sigmoid,
......
...@@ -61,8 +61,7 @@ from aesara.tensor.blas import ( ...@@ -61,8 +61,7 @@ from aesara.tensor.blas import (
res_is_a, res_is_a,
) )
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import Dot, dot, mean, mul, neg, outer, sqrt from aesara.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt
from aesara.tensor.nnet import sigmoid
from aesara.tensor.type import ( from aesara.tensor.type import (
cmatrix, cmatrix,
col, col,
......
...@@ -33,7 +33,7 @@ from aesara.graph.basic import Apply, graph_inputs ...@@ -33,7 +33,7 @@ from aesara.graph.basic import Apply, graph_inputs
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.sandbox.rng_mrg import MRG_RandomStream from aesara.sandbox.rng_mrg import MRG_RandomStream
from aesara.tensor.math import add, dot, exp, sqr from aesara.tensor.math import add, dot, exp, sigmoid, sqr
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tanh from aesara.tensor.math import tanh
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -466,7 +466,7 @@ class TestGrad: ...@@ -466,7 +466,7 @@ class TestGrad:
def make_grad_func(X): def make_grad_func(X):
Z = dot(X, W) + b Z = dot(X, W) + b
H = aesara.tensor.nnet.sigmoid(Z) H = sigmoid(Z)
cost = H.sum() cost = H.sum()
g = grad(cost, X) g = grad(cost, X)
return aesara.function([X, W, b], g, on_unused_input="ignore") return aesara.function([X, W, b], g, on_unused_input="ignore")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论