testgroup / pytensor · Commits · fe0365ad

Commit fe0365ad, authored Oct 05, 2020 by Brandon T. Willard
Implement new JAX conversions for theano.tensor.extra_ops
Parent: e464ba49

Showing 2 changed files with 300 additions and 11 deletions:

    tests/sandbox/test_jax.py    +56  −0
    theano/sandbox/jaxify.py     +244 −11
tests/sandbox/test_jax.py  +56 −0

```diff
@@ -664,3 +664,59 @@ def test_shared():
     jax_res = theano_jax_fn()
     assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
     np.testing.assert_allclose(jax_res, new_a_value * 2)
+
+
+def test_extra_ops():
+    a = tt.matrix("a")
+    a.tag.test_value = np.arange(6, dtype=theano.config.floatX).reshape((3, 2))
+
+    out = tt.extra_ops.cumsum(a, axis=0)
+    fgraph = theano.gof.FunctionGraph([a], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    out = tt.extra_ops.cumprod(a, axis=1)
+    fgraph = theano.gof.FunctionGraph([a], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    out = tt.extra_ops.diff(a, n=2, axis=1)
+    fgraph = theano.gof.FunctionGraph([a], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    out = tt.extra_ops.repeat(a, (3, 3), axis=1)
+    fgraph = theano.gof.FunctionGraph([a], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    # This function also cannot take symbolic input.
+    c = tt.as_tensor(5)
+    out = tt.extra_ops.bartlett(c)
+    fgraph = theano.gof.FunctionGraph([], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    with pytest.raises(NotImplementedError):
+        out = tt.extra_ops.fill_diagonal(a, c)
+        fgraph = theano.gof.FunctionGraph([a], [out])
+        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    with pytest.raises(NotImplementedError):
+        out = tt.extra_ops.fill_diagonal_offset(a, c, c)
+        fgraph = theano.gof.FunctionGraph([a], [out])
+        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    with pytest.raises(NotImplementedError):
+        out = tt.extra_ops.Unique(axis=1)(a)
+        fgraph = theano.gof.FunctionGraph([a], [out])
+        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    indices = np.arange(np.product((3, 4)))
+    out = tt.extra_ops.unravel_index(indices, (3, 4), order="C")
+    fgraph = theano.gof.FunctionGraph([], out)
+    compare_jax_and_py(
+        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
+    )
+
+    multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
+    out = tt.extra_ops.ravel_multi_index(multi_index, (3, 4))
+    fgraph = theano.gof.FunctionGraph([], [out])
+    compare_jax_and_py(
+        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
+    )
```
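Each case above pins down agreement between a `theano.tensor.extra_ops` Op and its `jax.numpy` counterpart (via the `compare_jax_and_py` helper, which is defined elsewhere in the test suite and not shown in this diff). The underlying property, sketched standalone with plain NumPy and JAX and no Theano involved:

```python
import numpy as np
import jax.numpy as jnp

x = np.arange(6, dtype="float32").reshape((3, 2))

# The conversions added in jaxify.py below delegate to these jnp functions,
# which must match their NumPy equivalents on the same inputs.
np.testing.assert_allclose(jnp.cumsum(x, axis=0), np.cumsum(x, axis=0))
np.testing.assert_allclose(jnp.cumprod(x, axis=1), np.cumprod(x, axis=1))
np.testing.assert_allclose(jnp.diff(x, n=2, axis=1), np.diff(x, n=2, axis=1))
np.testing.assert_allclose(jnp.repeat(x, (3, 3), axis=1), np.repeat(x, (3, 3), axis=1))
```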
theano/sandbox/jaxify.py  +244 −11
```diff
@@ -35,6 +35,7 @@ from theano.tensor.basic import (
     Alloc,
     Reshape,
     Join,
+    MaxAndArgmax,
 )
 from theano.scalar.basic import ScalarOp, Composite, Cast, Clip, Identity
 from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle
```
```diff
@@ -67,6 +68,21 @@ from theano.tensor.slinalg import (
     Solve,
 )
+from theano.tensor.type_other import MakeSlice
+from theano.tensor.extra_ops import (
+    CumOp,
+    DiffOp,
+    RepeatOp,
+    Bartlett,
+    FillDiagonal,
+    FillDiagonalOffset,
+    Unique,
+    UnravelIndex,
+    RavelMultiIndex,
+)
 
 if theano.config.floatX == "float64":
     jax.config.update("jax_enable_x64", True)
 else:
```
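The `jax_enable_x64` toggle in the context lines matters because JAX defaults to 32-bit precision: without it, float64 data is silently downcast, which would break graphs built under `floatX = "float64"`. A minimal standalone demonstration (not from this repository):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # must be set before creating arrays

x = jnp.array([1.0, 2.0], dtype="float64")
assert x.dtype == jnp.float64  # without x64 mode this would be float32
```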
```diff
@@ -82,7 +98,7 @@ except AttributeError:
     pass
 
 subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
-incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)
+incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
 
 
 def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
```

(`BaseAdvancedIncSubtensor` leaves the generic tuple because the `@@ -545,6 +588,20 @@` hunk below gives it a dedicated converter.)
```diff
@@ -116,15 +132,23 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
         if i in fgraph_inputs:
             idx = fgraph_inputs.index(i)
 
-            def jax_inputs_func(*inputs, i_dtype=i.dtype, idx=idx):
+            i_dtype = getattr(i, "dtype", None)
+
+            def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
                 return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype))
 
             input_f = jax_inputs_func
 
         elif i.owner is None:
 
-            def jax_data_func(*inputs, i_dtype=i.dtype, i_data=i.data):
-                return jnp.array(i_data, dtype=jnp.dtype(i_dtype))
+            i_dtype = getattr(i, "dtype", None)
+            i_data = i.data
+
+            def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
+                if i_dtype is None:
+                    return i_data
+                else:
+                    return jnp.array(i_data, dtype=jnp.dtype(i_dtype))
 
             input_f = jax_data_func
         else:
```
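The switch from `i.dtype` to `getattr(i, "dtype", None)` accommodates graph inputs that are not tensors and therefore carry no dtype, such as the `slice` objects produced by the new `MakeSlice` conversion below; those must pass through unwrapped rather than be fed to `jnp.array`. A standalone sketch of that pass-through behavior (the helper name here is illustrative only):

```python
import jax.numpy as jnp


def as_jax_constant(data, dtype):
    # Mirrors the patched jax_data_func: dtype-less constants pass through.
    if dtype is None:
        return data
    return jnp.array(data, dtype=jnp.dtype(dtype))


assert as_jax_constant(slice(0, 2), None) == slice(0, 2)
assert as_jax_constant([1.0, 2.0], "float32").dtype == jnp.float32
```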
```diff
@@ -158,6 +182,14 @@ def jax_funcify(op):
     raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op))
 
 
+@jax_funcify.register(MakeSlice)
+def jax_funcify_MakeSlice(op):
+    def makeslice(*x):
+        return slice(*x)
+
+    return makeslice
+
+
 @jax_funcify.register(ScalarOp)
 def jax_funcify_ScalarOp(op):
     func_name = op.nfunc_spec[0]
```
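`MakeSlice` simply rebuilds a Python `slice` from its start/stop/step inputs, which the subtensor conversions can then consume directly. For example:

```python
s = slice(*(0, 10, 2))  # what makeslice(*x) produces
assert list(range(20))[s] == [0, 2, 4, 6, 8]
```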
```diff
@@ -288,8 +320,13 @@ def jax_funcify_Shape_i(op):
 @jax_funcify.register(SpecifyShape)
 def jax_funcify_SpecifyShape(op):
     def specifyshape(x, shape):
-        assert x.ndim == shape.size
-        assert jnp.all(x.shape == shape), ("got shape", x.shape, "expected", shape)
+        assert x.ndim == len(shape)
+        assert jnp.all(x.shape == tuple(shape)), (
+            "got shape",
+            x.shape,
+            "expected",
+            shape,
+        )
         return x
 
     return specifyshape
```
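The reworked assertions treat `shape` as a plain Python sequence (`len`, `tuple`) rather than an ndarray with `.size`, so the check also works when the shape arrives as a list or tuple. In isolation:

```python
import jax.numpy as jnp


def specifyshape(x, shape):
    assert x.ndim == len(shape)            # works for tuple, list, or array
    assert jnp.all(x.shape == tuple(shape))
    return x


specifyshape(jnp.zeros((3, 2)), [3, 2])   # a list-valued shape now passes
```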
```diff
@@ -475,11 +512,15 @@ def jax_funcify_Scan(op):
 @jax_funcify.register(IfElse)
 def jax_funcify_IfElse(op):
+    n_outs = op.n_outs
 
-    def ifelse(cond, *args):
+    def ifelse(cond, *args, n_outs=n_outs):
         if cond:
-            return args[: op.n_outs]
+            res = args[:n_outs]
         else:
-            return args[op.n_outs :]
+            res = args[n_outs:]
+
+        return res if n_outs > 1 else res[0]
 
     return ifelse
```
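Beyond hoisting `op.n_outs` out of the closure, the patched `ifelse` unwraps single-output branches: a one-output `IfElse` now yields the bare value instead of a length-1 tuple. The behavior in isolation:

```python
def ifelse(cond, *args, n_outs=1):
    # First n_outs args are the "then" branch, the rest the "else" branch.
    res = args[:n_outs] if cond else args[n_outs:]
    return res if n_outs > 1 else res[0]


assert ifelse(True, 1, 2) == 1   # previously this returned (1,)
assert ifelse(False, 1, 2) == 2
```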
```diff
@@ -526,14 +567,16 @@ _ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops]
 def jax_funcify_IncSubtensor(op):
 
+    idx_list = op.idx_list
+
     if getattr(op, "set_instead_of_inc", False):
         jax_fn = jax.ops.index_update
     else:
         jax_fn = jax.ops.index_add
 
-    def incsubtensor(x, y, *ilist, jax_fn=jax_fn):
+    def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
         _ilist = list(ilist)
-        cdata = tuple(convert_indices(_ilist, idx) for idx in op.idx_list)
+        cdata = tuple(convert_indices(_ilist, idx) for idx in idx_list)
         if len(cdata) == 1:
             cdata = cdata[0]
```
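Hoisting `op.idx_list` into a default argument binds the value at definition time, so the returned closure no longer reaches back into `op` on every call. The general Python pitfall this pattern avoids, in miniature:

```python
# Default arguments are evaluated once, at definition time:
fns = [lambda i=i: i for i in range(3)]
assert [f() for f in fns] == [0, 1, 2]

# Free variables are looked up at call time, after the loop has finished:
fns = [lambda: i for i in range(3)]
assert [f() for f in fns] == [2, 2, 2]
```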
```diff
@@ -545,6 +588,20 @@ def jax_funcify_IncSubtensor(op):
 _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_ops]
 
 
+@jax_funcify.register(BaseAdvancedIncSubtensor)
+def jax_funcify_BaseAdvancedIncSubtensor(op):
+
+    if getattr(op, "set_instead_of_inc", False):
+        jax_fn = jax.ops.index_update
+    else:
+        jax_fn = jax.ops.index_add
+
+    def baseadvancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
+        return jax_fn(x, ilist, y)
+
+    return baseadvancedincsubtensor
+
+
 @jax_funcify.register(FunctionGraph)
 def jax_funcify_FunctionGraph(fgraph):
```
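The dedicated converter maps set/increment semantics onto JAX's functional index updates, spelled `jax.ops.index_update`/`jax.ops.index_add` in the JAX versions contemporary with this commit (newer JAX spells the same operations `x.at[idx].set(...)`/`x.at[idx].add(...)`). The semantics relied on, shown with the modern spelling:

```python
import jax.numpy as jnp

x = jnp.zeros(4)
idx = jnp.array([0, 2])

# Functional updates return a new array; x itself is never mutated.
assert (x.at[idx].set(1.0) == jnp.array([1.0, 0.0, 1.0, 0.0])).all()
assert (x.at[idx].add(2.0) == jnp.array([2.0, 0.0, 2.0, 0.0])).all()
assert (x == 0).all()
```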
```diff
@@ -656,6 +713,44 @@ def jax_funcify_Join(op):
     return join
 
 
+@jax_funcify.register(MaxAndArgmax)
+def jax_funcify_MaxAndArgmax(op):
+    axis = op.axis
+
+    def maxandargmax(x, axis=axis):
+        if axis is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axis)
+
+        max_res = jnp.max(x, axis)
+
+        # NumPy does not support multiple axes for argmax; this is a
+        # work-around
+        keep_axes = jnp.array(
+            [i for i in range(x.ndim) if i not in axes], dtype="int64"
+        )
+        # Not-reduced axes in front
+        transposed_x = jnp.transpose(
+            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
+        )
+        kept_shape = transposed_x.shape[: len(keep_axes)]
+        reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+        # Otherwise reshape would complain citing float arg
+        new_shape = kept_shape + (
+            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
+        )
+        reshaped_x = transposed_x.reshape(new_shape)
+
+        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
+
+        return max_res, max_idx_res
+
+    return maxandargmax
+
+
 @jax_funcify.register(ExtractDiag)
 def jax_funcify_ExtractDiag(op):
     offset = op.offset
```
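The work-around reduces a multi-axis argmax to a single-axis one: transpose the reduced axes to the back, flatten them into one axis, and take `argmax` over that last axis. Checking the idea with plain NumPy:

```python
import numpy as np

x = np.random.rand(2, 3, 4)
axes = (1, 2)

keep = tuple(i for i in range(x.ndim) if i not in axes)
flat = np.transpose(x, keep + axes).reshape(x.shape[0], -1)
flat_idx = np.argmax(flat, axis=-1)          # one index into the merged axes

# Unravelling recovers per-axis indices that locate the joint maximum.
i1, i2 = np.unravel_index(flat_idx, (3, 4))
for k in range(x.shape[0]):
    assert x[k, i1[k], i2[k]] == x[k].max()
```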
```diff
@@ -763,3 +858,141 @@ def jax_funcify_SVD(op):
         return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
 
     return svd
+
+
+@jax_funcify.register(CumOp)
+def jax_funcify_CumOp(op):
+    axis = op.axis
+    mode = op.mode
+
+    def cumop(x, axis=axis, mode=mode):
+        if mode == "add":
+            return jnp.cumsum(x, axis=axis)
+        else:
+            return jnp.cumprod(x, axis=axis)
+
+    return cumop
+
+
+@jax_funcify.register(DiffOp)
+def jax_funcify_DiffOp(op):
+    n = op.n
+    axis = op.axis
+
+    def diffop(x, n=n, axis=axis):
+        return jnp.diff(x, n=n, axis=axis)
+
+    return diffop
+
+
+@jax_funcify.register(RepeatOp)
+def jax_funcify_RepeatOp(op):
+    axis = op.axis
+
+    def repeatop(x, repeats, axis=axis):
+        return jnp.repeat(x, repeats, axis=axis)
+
+    return repeatop
+
+
+@jax_funcify.register(Bartlett)
+def jax_funcify_Bartlett(op):
+    def bartlett(x):
+        return jnp.bartlett(x)
+
+    return bartlett
+
+
+@jax_funcify.register(FillDiagonal)
+def jax_funcify_FillDiagonal(op):
+
+    # def filldiagonal(a, val):
+    #     if a.ndim == 2:
+    #         step = a.shape[1] + 1
+    #         end = a.shape[1] * a.shape[1]
+    #         a.flat[:end:step] = val
+    #     else:
+    #         jnp.fill_diagonal(a, val)
+    #
+    #     return a
+    #
+    # return filldiagonal
+
+    raise NotImplementedError("flatiter not implemented in JAX")
+
+
+@jax_funcify.register(FillDiagonalOffset)
+def jax_funcify_FillDiagonalOffset(op):
+
+    # def filldiagonaloffset(a, val, offset):
+    #     height, width = a.shape
+    #
+    #     if offset >= 0:
+    #         start = offset
+    #         num_of_step = min(min(width, height), width - offset)
+    #     else:
+    #         start = -offset * a.shape[1]
+    #         num_of_step = min(min(width, height), height + offset)
+    #
+    #     step = a.shape[1] + 1
+    #     end = start + step * num_of_step
+    #     a.flat[start:end:step] = val
+    #
+    #     return a
+    #
+    # return filldiagonaloffset
+
+    raise NotImplementedError("flatiter not implemented in JAX")
+
+
+@jax_funcify.register(Unique)
+def jax_funcify_Unique(op):
+    return_index = op.return_index
+    return_inverse = op.return_inverse
+    return_counts = op.return_counts
+    axis = op.axis
+
+    def unique(
+        x,
+        return_index=return_index,
+        return_inverse=return_inverse,
+        return_counts=return_counts,
+        axis=axis,
+    ):
+        param = {}
+        if return_index:
+            param["return_index"] = True
+        if return_inverse:
+            param["return_inverse"] = True
+        if return_counts:
+            param["return_counts"] = True
+        if axis is not None:
+            param["axis"] = axis
+        return jnp.unique(x, **param)
+
+    return unique
+
+
+@jax_funcify.register(UnravelIndex)
+def jax_funcify_UnravelIndex(op):
+    order = op.order
+
+    warn("JAX ignores the `order` parameter in `unravel_index`.")
+
+    def unravelindex(indices, dims, order=order):
+        return jnp.unravel_index(indices, dims)
+
+    return unravelindex
+
+
+@jax_funcify.register(RavelMultiIndex)
+def jax_funcify_RavelMultiIndex(op):
+    mode = op.mode
+    order = op.order
+
+    def ravelmultiindex(*inp, mode=mode, order=order):
+        multi_index, dims = inp[:-1], inp[-1]
+        return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)
+
+    return ravelmultiindex
```
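A detail worth noting in `ravelmultiindex`: the Op passes the index arrays and the target shape as one flat argument list, so the converter splits off the last input as `dims`. The equivalent NumPy computation, for illustration:

```python
import numpy as np

# Index arrays first, shape last, exactly as the Op delivers its inputs.
inp = (np.array([0, 1, 2]), np.array([0, 1, 3]), np.array([3, 4]))

multi_index, dims = inp[:-1], inp[-1]
# Row-major flattening of a (3, 4) grid: flat = row * 4 + col.
assert np.ravel_multi_index(multi_index, dims).tolist() == [0, 5, 11]
```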