Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
92d5450f
提交
92d5450f
authored
11月 16, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 17, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement Polygamma Op
上级
f6be5213
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
159 行增加
和
9 行删除
+159
-9
gradient.py
pytensor/gradient.py
+1
-1
math.py
pytensor/scalar/math.py
+53
-4
math.py
pytensor/tensor/math.py
+6
-0
math.py
pytensor/tensor/rewriting/math.py
+22
-2
test_scalar.py
tests/link/jax/test_scalar.py
+15
-0
test_math.py
tests/tensor/rewriting/test_math.py
+18
-2
test_math.py
tests/tensor/test_math.py
+44
-0
没有找到文件。
pytensor/gradient.py
浏览文件 @
92d5450f
...
...
@@ -101,7 +101,7 @@ def grad_undefined(op, x_pos, x, comment=""):
return
(
NullType
(
"This variable is Null because the grad method for "
f
"input {x_pos} ({x}) of the {op} op is
not implement
ed. {comment}"
f
"input {x_pos} ({x}) of the {op} op is
undefin
ed. {comment}"
)
)()
...
...
pytensor/scalar/math.py
浏览文件 @
92d5450f
...
...
@@ -13,7 +13,7 @@ import scipy.special
import
scipy.stats
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
grad_not_implemented
from
pytensor.gradient
import
grad_not_implemented
,
grad_undefined
from
pytensor.scalar.basic
import
BinaryScalarOp
,
ScalarOp
,
UnaryScalarOp
from
pytensor.scalar.basic
import
abs
as
scalar_abs
from
pytensor.scalar.basic
import
(
...
...
@@ -473,8 +473,12 @@ class TriGamma(UnaryScalarOp):
def
impl
(
self
,
x
):
return
TriGamma
.
st_impl
(
x
)
def
grad
(
self
,
inputs
,
outputs_gradients
):
raise
NotImplementedError
()
def
L_op
(
self
,
inputs
,
outputs
,
outputs_gradients
):
(
x
,)
=
inputs
(
g_out
,)
=
outputs_gradients
if
x
in
complex_types
:
raise
NotImplementedError
(
"gradient not implemented for complex types"
)
return
[
g_out
*
polygamma
(
2
,
x
)]
def
c_support_code
(
self
,
**
kwargs
):
# The implementation has been copied from
...
...
@@ -541,7 +545,52 @@ class TriGamma(UnaryScalarOp):
raise
NotImplementedError
(
"only floating point is implemented"
)
tri_gamma
=
TriGamma
(
upgrade_to_float
,
name
=
"tri_gamma"
)
# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
tri_gamma
=
TriGamma
(
upgrade_to_float_no_complex
,
name
=
"tri_gamma"
)
class PolyGamma(BinaryScalarOp):
    """Polygamma function of order n evaluated at x.

    It corresponds to the (n+1)th derivative of the log gamma function
    (order 0 is the digamma function).

    TODO: Because the first input is discrete and the output is continuous,
    the default elemwise inplace won't work, as it always tries to store the results in the first input.
    """

    # numpy/scipy dispatch spec: callable name, 2 inputs, 1 output.
    nfunc_spec = ("scipy.special.polygamma", 2, 1)

    @staticmethod
    def output_types_preference(n_type, x_type):
        # The derivative order must arrive as an integer dtype.
        if n_type not in discrete_types:
            raise TypeError(
                f"Polygamma order parameter must be discrete, got {n_type} dtype"
            )
        # Scipy doesn't support it
        return upgrade_to_float_no_complex(x_type)

    @staticmethod
    def st_impl(n, x):
        return scipy.special.polygamma(n, x)

    def impl(self, n, x):
        return PolyGamma.st_impl(n, x)

    def L_op(self, inputs, outputs, output_gradients):
        n, x = inputs
        [g_out] = output_gradients
        if x in complex_types:
            raise NotImplementedError("gradient not implemented for complex types")
        # d/dx polygamma(n, x) = polygamma(n + 1, x); the order input is
        # discrete, so its gradient is undefined.
        x_grad = g_out * self(n + 1, x)
        return [grad_undefined(self, 0, n), x_grad]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


polygamma = PolyGamma(name="polygamma")
class
Chi2SF
(
BinaryScalarOp
):
...
...
pytensor/tensor/math.py
浏览文件 @
92d5450f
...
...
@@ -1369,6 +1369,11 @@ def tri_gamma(a):
"""second derivative of the log gamma function"""
@scalar_elemwise
def polygamma(n, x):
    """Polygamma function of order n evaluated at x.

    It corresponds to the (n+1)th derivative of the log gamma function;
    order ``n`` must have a discrete (integer) dtype.
    """
@scalar_elemwise
def chi2sf(x, k):
    """chi squared survival function, evaluated at x with k degrees of freedom"""
...
...
@@ -3008,6 +3013,7 @@ __all__ = [
"psi"
,
"digamma"
,
"tri_gamma"
,
"polygamma"
,
"chi2sf"
,
"gammainc"
,
"gammaincc"
,
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
92d5450f
...
...
@@ -52,6 +52,7 @@ from pytensor.tensor.math import (
from
pytensor.tensor.math
import
abs
as
at_abs
from
pytensor.tensor.math
import
(
add
,
digamma
,
dot
,
eq
,
erf
,
...
...
@@ -68,7 +69,7 @@ from pytensor.tensor.math import (
makeKeepDims
,
)
from
pytensor.tensor.math
import
max
as
at_max
from
pytensor.tensor.math
import
maximum
,
mul
,
neg
from
pytensor.tensor.math
import
maximum
,
mul
,
neg
,
polygamma
from
pytensor.tensor.math
import
pow
as
at_pow
from
pytensor.tensor.math
import
(
prod
,
...
...
@@ -81,7 +82,7 @@ from pytensor.tensor.math import (
sub
,
)
from
pytensor.tensor.math
import
sum
as
at_sum
from
pytensor.tensor.math
import
true_div
from
pytensor.tensor.math
import
tr
i_gamma
,
tr
ue_div
from
pytensor.tensor.rewriting.basic
import
(
alloc_like
,
broadcasted_by
,
...
...
@@ -3638,3 +3639,22 @@ def local_useless_conj(fgraph, node):
x
=
node
.
inputs
[
0
]
if
x
.
type
.
dtype
not
in
complex_dtypes
:
return
[
x
]
# polygamma(0, x) is the digamma function; specialize to the cheaper
# dedicated Op when the order is the constant 0.
local_polygamma_to_digamma = PatternNodeRewriter(
    (polygamma, 0, "x"),
    (digamma, "x"),
    allow_multiple_clients=True,
    name="local_polygamma_to_digamma",
)

register_specialize(local_polygamma_to_digamma)

# Likewise, polygamma(1, x) is the trigamma function.
local_polygamma_to_tri_gamma = PatternNodeRewriter(
    (polygamma, 1, "x"),
    (tri_gamma, "x"),
    allow_multiple_clients=True,
    name="local_polygamma_to_tri_gamma",
)

register_specialize(local_polygamma_to_tri_gamma)
tests/link/jax/test_scalar.py
浏览文件 @
92d5450f
...
...
@@ -20,6 +20,7 @@ from pytensor.tensor.math import (
iv
,
log
,
log1mexp
,
polygamma
,
psi
,
sigmoid
,
softplus
,
...
...
@@ -178,6 +179,20 @@ def test_tri_gamma():
compare_jax_and_py
(
fg
,
[
np
.
array
([
3.0
,
5.0
])])
def test_polygamma():
    """Check the JAX dispatch of `polygamma` against the Python backend."""
    order = vector("n", dtype="int32")
    value = vector("x", dtype="float32")
    graph = FunctionGraph([order, value], [polygamma(order, value)])
    test_inputs = [
        np.array([0, 1, 2]).astype("int32"),
        np.array([0.5, 0.9, 2.5]).astype("float32"),
    ]
    compare_jax_and_py(graph, test_inputs)
def
test_log1mexp
():
x
=
vector
(
"x"
)
out
=
log1mexp
(
x
)
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
92d5450f
...
...
@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from
pytensor.graph.rewriting.utils
import
is_same_graph
,
rewrite_graph
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.printing
import
debugprint
from
pytensor.scalar
import
Po
w
from
pytensor.scalar
import
Po
lyGamma
,
Pow
,
Psi
,
TriGamma
from
pytensor.tensor
import
inplace
from
pytensor.tensor.basic
import
Alloc
,
constant
,
join
,
second
,
switch
from
pytensor.tensor.blas
import
Dot22
,
Gemv
...
...
@@ -69,7 +69,7 @@ from pytensor.tensor.math import (
from
pytensor.tensor.math
import
max
as
at_max
from
pytensor.tensor.math
import
maximum
from
pytensor.tensor.math
import
min
as
at_min
from
pytensor.tensor.math
import
minimum
,
mul
,
neg
,
neq
from
pytensor.tensor.math
import
minimum
,
mul
,
neg
,
neq
,
polygamma
from
pytensor.tensor.math
import
pow
as
pt_pow
from
pytensor.tensor.math
import
(
prod
,
...
...
@@ -4236,3 +4236,19 @@ def test_logdiffexp():
np
.
testing
.
assert_almost_equal
(
f
(
x_test
,
y_test
),
np
.
log
(
np
.
exp
(
x_test
)
-
np
.
exp
(
y_test
))
)
def test_polygamma_specialization():
    """Constant-order polygamma should specialize to Psi/TriGamma scalar Ops."""
    x = vector("x")
    outputs = [polygamma(order, x) for order in range(3)]
    specialize_mode = get_default_mode().including("specialize")
    fn = pytensor.function([x], outputs, mode=specialize_mode)
    rewritten_outs = fn.maker.fgraph.outputs
    # Orders 0 and 1 are rewritten; order 2 keeps the generic Op.
    for out, expected_op in zip(rewritten_outs, (Psi, TriGamma, PolyGamma)):
        assert isinstance(out.owner.op.scalar_op, expected_op)
tests/tensor/test_math.py
浏览文件 @
92d5450f
...
...
@@ -7,6 +7,7 @@ from itertools import product
import
numpy
as
np
import
pytest
import
scipy.special
from
numpy.testing
import
assert_array_equal
from
scipy.special
import
logsumexp
as
scipy_logsumexp
...
...
@@ -64,6 +65,7 @@ from pytensor.tensor.math import (
cov
,
deg2rad
,
dense_dot
,
digamma
,
dot
,
eq
,
exp
,
...
...
@@ -93,6 +95,7 @@ from pytensor.tensor.math import (
neg
,
neq
,
outer
,
polygamma
,
power
,
ptp
,
rad2deg
,
...
...
@@ -3470,3 +3473,44 @@ class TestMatMul:
fn
=
function
([
x
,
y
],
x
@
y
,
mode
=
"FAST_RUN"
)
[
node
]
=
fn
.
maker
.
fgraph
.
apply_nodes
assert
isinstance
(
node
.
op
,
Dot22
)
class TestPolyGamma:
    """Tests for the tensor-level `polygamma` Op."""

    def test_basic(self):
        # Broadcasting a vector of orders against a scalar x must match scipy.
        n = vector("n", dtype="int64")
        x = scalar("x")
        np.testing.assert_allclose(
            polygamma(n, x).eval({n: [0, 1], x: 0.5}),
            scipy.special.polygamma([0, 1], 0.5),
        )

    def test_continuous_n_raises(self):
        # The order parameter must have a discrete dtype.
        n = scalar("n", dtype="float64")
        with pytest.raises(TypeError, match="must be discrete"):
            polygamma(n, 0.5)

    def test_complex_x_raises(self):
        # Complex evaluation points are rejected (scipy does not support them).
        x = scalar(dtype="complex128")
        with pytest.raises(TypeError, match="complex argument not supported"):
            polygamma(0, x)

    def test_output_dtype(self):
        n = scalar("n", dtype="int64")
        # Fixed: these comparisons were bare expressions with no `assert`,
        # so the dtype checks could never fail.
        assert polygamma(n, scalar("x", dtype="float32")).dtype == "float32"
        assert polygamma(n, scalar("x", dtype="float64")).dtype == "float64"
        assert polygamma(n, scalar("x", dtype="int32")).dtype == "float64"

    def test_grad_x(self):
        # The gradient of polygamma(0, x) w.r.t. x must match digamma's.
        x = scalar("x")
        op_grad = grad(polygamma(0, x), wrt=x)
        ref_grad = grad(digamma(x), wrt=x)
        np.testing.assert_allclose(
            op_grad.eval({x: 0.9}),
            ref_grad.eval({x: 0.9}),
        )

    def test_grad_n_undefined(self):
        # The order input is discrete, so its gradient is Null.
        n = scalar(dtype="int64")
        with pytest.raises(NullTypeGradError):
            grad(polygamma(n, 0.5), wrt=n)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论