Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e6914913
提交
e6914913
authored
4月 12, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
4月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use jax.numpy.vectorize for Elemwise Composite Ops
上级
4fa10665
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
85 行增加
和
75 行删除
+85
-75
dispatch.py
aesara/link/jax/dispatch.py
+67
-67
test_jax.py
tests/link/test_jax.py
+18
-8
没有找到文件。
aesara/link/jax/dispatch.py
浏览文件 @
e6914913
...
@@ -131,7 +131,7 @@ def jax_funcify(op, **kwargs):
...
@@ -131,7 +131,7 @@ def jax_funcify(op, **kwargs):
@jax_funcify.register
(
MakeSlice
)
@jax_funcify.register
(
MakeSlice
)
def
jax_funcify_MakeSlice
(
op
):
def
jax_funcify_MakeSlice
(
op
,
**
kwargs
):
def
makeslice
(
*
x
):
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
slice
(
*
x
)
...
@@ -139,7 +139,7 @@ def jax_funcify_MakeSlice(op):
...
@@ -139,7 +139,7 @@ def jax_funcify_MakeSlice(op):
@jax_funcify.register
(
ScalarOp
)
@jax_funcify.register
(
ScalarOp
)
def
jax_funcify_ScalarOp
(
op
):
def
jax_funcify_ScalarOp
(
op
,
**
kwargs
):
func_name
=
op
.
nfunc_spec
[
0
]
func_name
=
op
.
nfunc_spec
[
0
]
if
"."
in
func_name
:
if
"."
in
func_name
:
...
@@ -168,7 +168,7 @@ def jax_funcify_ScalarOp(op):
...
@@ -168,7 +168,7 @@ def jax_funcify_ScalarOp(op):
@jax_funcify.register
(
Clip
)
@jax_funcify.register
(
Clip
)
def
jax_funcify_Clip
(
op
):
def
jax_funcify_Clip
(
op
,
**
kwargs
):
def
clip
(
x
,
min
,
max
):
def
clip
(
x
,
min
,
max
):
return
jnp
.
where
(
x
<
min
,
min
,
jnp
.
where
(
x
>
max
,
max
,
x
))
return
jnp
.
where
(
x
<
min
,
min
,
jnp
.
where
(
x
>
max
,
max
,
x
))
...
@@ -176,7 +176,7 @@ def jax_funcify_Clip(op):
...
@@ -176,7 +176,7 @@ def jax_funcify_Clip(op):
@jax_funcify.register
(
Identity
)
@jax_funcify.register
(
Identity
)
def
jax_funcify_Identity
(
op
):
def
jax_funcify_Identity
(
op
,
**
kwargs
):
def
identity
(
x
):
def
identity
(
x
):
return
x
return
x
...
@@ -184,7 +184,7 @@ def jax_funcify_Identity(op):
...
@@ -184,7 +184,7 @@ def jax_funcify_Identity(op):
@jax_funcify.register
(
Softmax
)
@jax_funcify.register
(
Softmax
)
def
jax_funcify_Softmax
(
op
):
def
jax_funcify_Softmax
(
op
,
**
kwargs
):
def
softmax
(
x
):
def
softmax
(
x
):
return
jax
.
nn
.
softmax
(
x
)
return
jax
.
nn
.
softmax
(
x
)
...
@@ -192,7 +192,7 @@ def jax_funcify_Softmax(op):
...
@@ -192,7 +192,7 @@ def jax_funcify_Softmax(op):
@jax_funcify.register
(
LogSoftmax
)
@jax_funcify.register
(
LogSoftmax
)
def
jax_funcify_LogSoftmax
(
op
):
def
jax_funcify_LogSoftmax
(
op
,
**
kwargs
):
def
log_softmax
(
x
):
def
log_softmax
(
x
):
return
jax
.
nn
.
log_softmax
(
x
)
return
jax
.
nn
.
log_softmax
(
x
)
...
@@ -200,7 +200,7 @@ def jax_funcify_LogSoftmax(op):
...
@@ -200,7 +200,7 @@ def jax_funcify_LogSoftmax(op):
@jax_funcify.register
(
ScalarSoftplus
)
@jax_funcify.register
(
ScalarSoftplus
)
def
jax_funcify_ScalarSoftplus
(
op
):
def
jax_funcify_ScalarSoftplus
(
op
,
**
kwargs
):
def
scalarsoftplus
(
x
):
def
scalarsoftplus
(
x
):
return
jnp
.
where
(
x
<
-
30.0
,
0.0
,
jnp
.
where
(
x
>
30.0
,
x
,
jnp
.
log1p
(
jnp
.
exp
(
x
))))
return
jnp
.
where
(
x
<
-
30.0
,
0.0
,
jnp
.
where
(
x
>
30.0
,
x
,
jnp
.
log1p
(
jnp
.
exp
(
x
))))
...
@@ -208,7 +208,7 @@ def jax_funcify_ScalarSoftplus(op):
...
@@ -208,7 +208,7 @@ def jax_funcify_ScalarSoftplus(op):
@jax_funcify.register
(
Second
)
@jax_funcify.register
(
Second
)
def
jax_funcify_Second
(
op
):
def
jax_funcify_Second
(
op
,
**
kwargs
):
def
second
(
x
,
y
):
def
second
(
x
,
y
):
return
jnp
.
broadcast_to
(
y
,
x
.
shape
)
return
jnp
.
broadcast_to
(
y
,
x
.
shape
)
...
@@ -216,7 +216,7 @@ def jax_funcify_Second(op):
...
@@ -216,7 +216,7 @@ def jax_funcify_Second(op):
@jax_funcify.register
(
AllocDiag
)
@jax_funcify.register
(
AllocDiag
)
def
jax_funcify_AllocDiag
(
op
):
def
jax_funcify_AllocDiag
(
op
,
**
kwargs
):
offset
=
op
.
offset
offset
=
op
.
offset
def
allocdiag
(
v
,
offset
=
offset
):
def
allocdiag
(
v
,
offset
=
offset
):
...
@@ -226,7 +226,7 @@ def jax_funcify_AllocDiag(op):
...
@@ -226,7 +226,7 @@ def jax_funcify_AllocDiag(op):
@jax_funcify.register
(
AllocEmpty
)
@jax_funcify.register
(
AllocEmpty
)
def
jax_funcify_AllocEmpty
(
op
):
def
jax_funcify_AllocEmpty
(
op
,
**
kwargs
):
def
allocempty
(
*
shape
):
def
allocempty
(
*
shape
):
return
jnp
.
empty
(
shape
,
dtype
=
op
.
dtype
)
return
jnp
.
empty
(
shape
,
dtype
=
op
.
dtype
)
...
@@ -234,7 +234,7 @@ def jax_funcify_AllocEmpty(op):
...
@@ -234,7 +234,7 @@ def jax_funcify_AllocEmpty(op):
@jax_funcify.register
(
Alloc
)
@jax_funcify.register
(
Alloc
)
def
jax_funcify_Alloc
(
op
):
def
jax_funcify_Alloc
(
op
,
**
kwargs
):
def
alloc
(
x
,
*
shape
):
def
alloc
(
x
,
*
shape
):
res
=
jnp
.
broadcast_to
(
x
,
shape
)
res
=
jnp
.
broadcast_to
(
x
,
shape
)
return
res
return
res
...
@@ -243,7 +243,7 @@ def jax_funcify_Alloc(op):
...
@@ -243,7 +243,7 @@ def jax_funcify_Alloc(op):
@jax_funcify.register
(
Dot
)
@jax_funcify.register
(
Dot
)
def
jax_funcify_Dot
(
op
):
def
jax_funcify_Dot
(
op
,
**
kwargs
):
def
dot
(
x
,
y
):
def
dot
(
x
,
y
):
return
jnp
.
dot
(
x
,
y
)
return
jnp
.
dot
(
x
,
y
)
...
@@ -251,7 +251,7 @@ def jax_funcify_Dot(op):
...
@@ -251,7 +251,7 @@ def jax_funcify_Dot(op):
@jax_funcify.register
(
ARange
)
@jax_funcify.register
(
ARange
)
def
jax_funcify_ARange
(
op
):
def
jax_funcify_ARange
(
op
,
**
kwargs
):
# XXX: This currently requires concrete arguments.
# XXX: This currently requires concrete arguments.
def
arange
(
start
,
stop
,
step
):
def
arange
(
start
,
stop
,
step
):
return
jnp
.
arange
(
start
,
stop
,
step
,
dtype
=
op
.
dtype
)
return
jnp
.
arange
(
start
,
stop
,
step
,
dtype
=
op
.
dtype
)
...
@@ -274,7 +274,7 @@ def jnp_safe_copy(x):
...
@@ -274,7 +274,7 @@ def jnp_safe_copy(x):
@jax_funcify.register
(
DeepCopyOp
)
@jax_funcify.register
(
DeepCopyOp
)
def
jax_funcify_DeepCopyOp
(
op
):
def
jax_funcify_DeepCopyOp
(
op
,
**
kwargs
):
def
deepcopyop
(
x
):
def
deepcopyop
(
x
):
return
jnp_safe_copy
(
x
)
return
jnp_safe_copy
(
x
)
...
@@ -282,7 +282,7 @@ def jax_funcify_DeepCopyOp(op):
...
@@ -282,7 +282,7 @@ def jax_funcify_DeepCopyOp(op):
@jax_funcify.register
(
Shape
)
@jax_funcify.register
(
Shape
)
def
jax_funcify_Shape
(
op
):
def
jax_funcify_Shape
(
op
,
**
kwargs
):
def
shape
(
x
):
def
shape
(
x
):
return
jnp
.
shape
(
x
)
return
jnp
.
shape
(
x
)
...
@@ -290,7 +290,7 @@ def jax_funcify_Shape(op):
...
@@ -290,7 +290,7 @@ def jax_funcify_Shape(op):
@jax_funcify.register
(
Shape_i
)
@jax_funcify.register
(
Shape_i
)
def
jax_funcify_Shape_i
(
op
):
def
jax_funcify_Shape_i
(
op
,
**
kwargs
):
i
=
op
.
i
i
=
op
.
i
def
shape_i
(
x
):
def
shape_i
(
x
):
...
@@ -300,7 +300,7 @@ def jax_funcify_Shape_i(op):
...
@@ -300,7 +300,7 @@ def jax_funcify_Shape_i(op):
@jax_funcify.register
(
SpecifyShape
)
@jax_funcify.register
(
SpecifyShape
)
def
jax_funcify_SpecifyShape
(
op
):
def
jax_funcify_SpecifyShape
(
op
,
**
kwargs
):
def
specifyshape
(
x
,
shape
):
def
specifyshape
(
x
,
shape
):
assert
x
.
ndim
==
len
(
shape
)
assert
x
.
ndim
==
len
(
shape
)
assert
jnp
.
all
(
x
.
shape
==
tuple
(
shape
)),
(
assert
jnp
.
all
(
x
.
shape
==
tuple
(
shape
)),
(
...
@@ -315,7 +315,7 @@ def jax_funcify_SpecifyShape(op):
...
@@ -315,7 +315,7 @@ def jax_funcify_SpecifyShape(op):
@jax_funcify.register
(
Rebroadcast
)
@jax_funcify.register
(
Rebroadcast
)
def
jax_funcify_Rebroadcast
(
op
):
def
jax_funcify_Rebroadcast
(
op
,
**
kwargs
):
op_axis
=
op
.
axis
op_axis
=
op
.
axis
def
rebroadcast
(
x
):
def
rebroadcast
(
x
):
...
@@ -331,7 +331,7 @@ def jax_funcify_Rebroadcast(op):
...
@@ -331,7 +331,7 @@ def jax_funcify_Rebroadcast(op):
@jax_funcify.register
(
ViewOp
)
@jax_funcify.register
(
ViewOp
)
def
jax_funcify_ViewOp
(
op
):
def
jax_funcify_ViewOp
(
op
,
**
kwargs
):
def
viewop
(
x
):
def
viewop
(
x
):
return
x
return
x
...
@@ -339,7 +339,7 @@ def jax_funcify_ViewOp(op):
...
@@ -339,7 +339,7 @@ def jax_funcify_ViewOp(op):
@jax_funcify.register
(
Cast
)
@jax_funcify.register
(
Cast
)
def
jax_funcify_Cast
(
op
):
def
jax_funcify_Cast
(
op
,
**
kwargs
):
def
cast
(
x
):
def
cast
(
x
):
return
jnp
.
array
(
x
)
.
astype
(
op
.
o_type
.
dtype
)
return
jnp
.
array
(
x
)
.
astype
(
op
.
o_type
.
dtype
)
...
@@ -347,7 +347,7 @@ def jax_funcify_Cast(op):
...
@@ -347,7 +347,7 @@ def jax_funcify_Cast(op):
@jax_funcify.register
(
TensorFromScalar
)
@jax_funcify.register
(
TensorFromScalar
)
def
jax_funcify_TensorFromScalar
(
op
):
def
jax_funcify_TensorFromScalar
(
op
,
**
kwargs
):
def
tensor_from_scalar
(
x
):
def
tensor_from_scalar
(
x
):
return
jnp
.
array
(
x
)
return
jnp
.
array
(
x
)
...
@@ -355,7 +355,7 @@ def jax_funcify_TensorFromScalar(op):
...
@@ -355,7 +355,7 @@ def jax_funcify_TensorFromScalar(op):
@jax_funcify.register
(
ScalarFromTensor
)
@jax_funcify.register
(
ScalarFromTensor
)
def
jax_funcify_ScalarFromTensor
(
op
):
def
jax_funcify_ScalarFromTensor
(
op
,
**
kwargs
):
def
scalar_from_tensor
(
x
):
def
scalar_from_tensor
(
x
):
return
jnp
.
array
(
x
)
.
flatten
()[
0
]
return
jnp
.
array
(
x
)
.
flatten
()[
0
]
...
@@ -363,30 +363,25 @@ def jax_funcify_ScalarFromTensor(op):
...
@@ -363,30 +363,25 @@ def jax_funcify_ScalarFromTensor(op):
@jax_funcify.register
(
Elemwise
)
@jax_funcify.register
(
Elemwise
)
def
jax_funcify_Elemwise
(
op
):
def
jax_funcify_Elemwise
(
op
,
**
kwargs
):
scalar_op
=
op
.
scalar_op
scalar_op
=
op
.
scalar_op
return
jax_funcify
(
scalar_op
)
return
jax_funcify
(
scalar_op
,
**
kwargs
)
@jax_funcify.register
(
Composite
)
@jax_funcify.register
(
Composite
)
def
jax_funcify_Composite
(
op
):
def
jax_funcify_Composite
(
op
,
vectorize
=
True
,
**
kwargs
):
# This approach basically gets rid of the fused `Elemwise` by turning each
# `Op` in the `Composite` back into individually broadcasted NumPy-like
# operations.
# TODO: A better approach would involve something like `jax.vmap` or some
# other operation that can perform the broadcasting that `Elemwise` does.
jax_impl
=
jax_funcify
(
op
.
fgraph
)
jax_impl
=
jax_funcify
(
op
.
fgraph
)
def
composite
(
*
args
):
def
composite
(
*
args
):
return
jax_impl
(
*
args
)[
0
]
return
jax_impl
(
*
args
)[
0
]
return
composite
return
jnp
.
vectorize
(
composite
)
@jax_funcify.register
(
Scan
)
@jax_funcify.register
(
Scan
)
def
jax_funcify_Scan
(
op
):
def
jax_funcify_Scan
(
op
,
**
kwargs
):
inner_fg
=
FunctionGraph
(
op
.
inputs
,
op
.
outputs
)
inner_fg
=
FunctionGraph
(
op
.
inputs
,
op
.
outputs
)
jax_aet_inner_func
=
jax_funcify
(
inner_fg
)
jax_aet_inner_func
=
jax_funcify
(
inner_fg
,
**
kwargs
)
def
scan
(
*
outer_inputs
):
def
scan
(
*
outer_inputs
):
scan_args
=
ScanArgs
(
scan_args
=
ScanArgs
(
...
@@ -536,7 +531,7 @@ def jax_funcify_Scan(op):
...
@@ -536,7 +531,7 @@ def jax_funcify_Scan(op):
@jax_funcify.register
(
IfElse
)
@jax_funcify.register
(
IfElse
)
def
jax_funcify_IfElse
(
op
):
def
jax_funcify_IfElse
(
op
,
**
kwargs
):
n_outs
=
op
.
n_outs
n_outs
=
op
.
n_outs
def
ifelse
(
cond
,
*
args
,
n_outs
=
n_outs
):
def
ifelse
(
cond
,
*
args
,
n_outs
=
n_outs
):
...
@@ -549,7 +544,7 @@ def jax_funcify_IfElse(op):
...
@@ -549,7 +544,7 @@ def jax_funcify_IfElse(op):
@jax_funcify.register
(
Subtensor
)
@jax_funcify.register
(
Subtensor
)
def
jax_funcify_Subtensor
(
op
):
def
jax_funcify_Subtensor
(
op
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
...
@@ -568,7 +563,7 @@ def jax_funcify_Subtensor(op):
...
@@ -568,7 +563,7 @@ def jax_funcify_Subtensor(op):
_
=
[
jax_funcify
.
register
(
op
,
jax_funcify_Subtensor
)
for
op
in
subtensor_ops
]
_
=
[
jax_funcify
.
register
(
op
,
jax_funcify_Subtensor
)
for
op
in
subtensor_ops
]
def
jax_funcify_IncSubtensor
(
op
):
def
jax_funcify_IncSubtensor
(
op
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
...
@@ -591,7 +586,7 @@ _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_o
...
@@ -591,7 +586,7 @@ _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_o
@jax_funcify.register
(
AdvancedIncSubtensor
)
@jax_funcify.register
(
AdvancedIncSubtensor
)
def
jax_funcify_AdvancedIncSubtensor
(
op
):
def
jax_funcify_AdvancedIncSubtensor
(
op
,
**
kwargs
):
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
jax_fn
=
jax
.
ops
.
index_update
jax_fn
=
jax
.
ops
.
index_update
...
@@ -606,7 +601,12 @@ def jax_funcify_AdvancedIncSubtensor(op):
...
@@ -606,7 +601,12 @@ def jax_funcify_AdvancedIncSubtensor(op):
@jax_funcify.register
(
FunctionGraph
)
@jax_funcify.register
(
FunctionGraph
)
def
jax_funcify_FunctionGraph
(
def
jax_funcify_FunctionGraph
(
fgraph
,
order
=
None
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
fgraph
,
order
=
None
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
,
**
kwargs
,
):
):
if
order
is
None
:
if
order
is
None
:
...
@@ -642,7 +642,7 @@ def jax_funcify_FunctionGraph(
...
@@ -642,7 +642,7 @@ def jax_funcify_FunctionGraph(
body_assigns
=
[]
body_assigns
=
[]
for
node
in
order
:
for
node
in
order
:
jax_func
=
jax_funcify
(
node
.
op
)
jax_func
=
jax_funcify
(
node
.
op
,
node
=
node
,
**
kwargs
)
# Create a local alias with a unique name
# Create a local alias with a unique name
local_jax_func_name
=
unique_name
(
jax_func
)
local_jax_func_name
=
unique_name
(
jax_func
)
...
@@ -696,7 +696,7 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
...
@@ -696,7 +696,7 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
@jax_funcify.register
(
CAReduce
)
@jax_funcify.register
(
CAReduce
)
def
jax_funcify_CAReduce
(
op
):
def
jax_funcify_CAReduce
(
op
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
op_nfunc_spec
=
getattr
(
op
,
"nfunc_spec"
,
None
)
op_nfunc_spec
=
getattr
(
op
,
"nfunc_spec"
,
None
)
scalar_nfunc_spec
=
getattr
(
op
.
scalar_op
,
"nfunc_spec"
,
None
)
scalar_nfunc_spec
=
getattr
(
op
.
scalar_op
,
"nfunc_spec"
,
None
)
...
@@ -739,7 +739,7 @@ def jax_funcify_CAReduce(op):
...
@@ -739,7 +739,7 @@ def jax_funcify_CAReduce(op):
@jax_funcify.register
(
MakeVector
)
@jax_funcify.register
(
MakeVector
)
def
jax_funcify_MakeVector
(
op
):
def
jax_funcify_MakeVector
(
op
,
**
kwargs
):
def
makevector
(
*
x
):
def
makevector
(
*
x
):
return
jnp
.
array
(
x
,
dtype
=
op
.
dtype
)
return
jnp
.
array
(
x
,
dtype
=
op
.
dtype
)
...
@@ -747,7 +747,7 @@ def jax_funcify_MakeVector(op):
...
@@ -747,7 +747,7 @@ def jax_funcify_MakeVector(op):
@jax_funcify.register
(
Reshape
)
@jax_funcify.register
(
Reshape
)
def
jax_funcify_Reshape
(
op
):
def
jax_funcify_Reshape
(
op
,
**
kwargs
):
def
reshape
(
x
,
shape
):
def
reshape
(
x
,
shape
):
return
jnp
.
reshape
(
x
,
shape
)
return
jnp
.
reshape
(
x
,
shape
)
...
@@ -755,7 +755,7 @@ def jax_funcify_Reshape(op):
...
@@ -755,7 +755,7 @@ def jax_funcify_Reshape(op):
@jax_funcify.register
(
DimShuffle
)
@jax_funcify.register
(
DimShuffle
)
def
jax_funcify_DimShuffle
(
op
):
def
jax_funcify_DimShuffle
(
op
,
**
kwargs
):
def
dimshuffle
(
x
):
def
dimshuffle
(
x
):
res
=
jnp
.
transpose
(
x
,
op
.
shuffle
+
op
.
drop
)
res
=
jnp
.
transpose
(
x
,
op
.
shuffle
+
op
.
drop
)
...
@@ -776,7 +776,7 @@ def jax_funcify_DimShuffle(op):
...
@@ -776,7 +776,7 @@ def jax_funcify_DimShuffle(op):
@jax_funcify.register
(
Join
)
@jax_funcify.register
(
Join
)
def
jax_funcify_Join
(
op
):
def
jax_funcify_Join
(
op
,
**
kwargs
):
def
join
(
axis
,
*
tensors
):
def
join
(
axis
,
*
tensors
):
# tensors could also be tuples, and in this case they don't have a ndim
# tensors could also be tuples, and in this case they don't have a ndim
tensors
=
[
jnp
.
asarray
(
tensor
)
for
tensor
in
tensors
]
tensors
=
[
jnp
.
asarray
(
tensor
)
for
tensor
in
tensors
]
...
@@ -802,7 +802,7 @@ def jax_funcify_Join(op):
...
@@ -802,7 +802,7 @@ def jax_funcify_Join(op):
@jax_funcify.register
(
MaxAndArgmax
)
@jax_funcify.register
(
MaxAndArgmax
)
def
jax_funcify_MaxAndArgmax
(
op
):
def
jax_funcify_MaxAndArgmax
(
op
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
def
maxandargmax
(
x
,
axis
=
axis
):
def
maxandargmax
(
x
,
axis
=
axis
):
...
@@ -840,7 +840,7 @@ def jax_funcify_MaxAndArgmax(op):
...
@@ -840,7 +840,7 @@ def jax_funcify_MaxAndArgmax(op):
@jax_funcify.register
(
ExtractDiag
)
@jax_funcify.register
(
ExtractDiag
)
def
jax_funcify_ExtractDiag
(
op
):
def
jax_funcify_ExtractDiag
(
op
,
**
kwargs
):
offset
=
op
.
offset
offset
=
op
.
offset
axis1
=
op
.
axis1
axis1
=
op
.
axis1
axis2
=
op
.
axis2
axis2
=
op
.
axis2
...
@@ -852,7 +852,7 @@ def jax_funcify_ExtractDiag(op):
...
@@ -852,7 +852,7 @@ def jax_funcify_ExtractDiag(op):
@jax_funcify.register
(
Cholesky
)
@jax_funcify.register
(
Cholesky
)
def
jax_funcify_Cholesky
(
op
):
def
jax_funcify_Cholesky
(
op
,
**
kwargs
):
lower
=
op
.
lower
lower
=
op
.
lower
def
cholesky
(
a
,
lower
=
lower
):
def
cholesky
(
a
,
lower
=
lower
):
...
@@ -862,7 +862,7 @@ def jax_funcify_Cholesky(op):
...
@@ -862,7 +862,7 @@ def jax_funcify_Cholesky(op):
@jax_funcify.register
(
Solve
)
@jax_funcify.register
(
Solve
)
def
jax_funcify_Solve
(
op
):
def
jax_funcify_Solve
(
op
,
**
kwargs
):
if
op
.
A_structure
==
"lower_triangular"
:
if
op
.
A_structure
==
"lower_triangular"
:
lower
=
True
lower
=
True
...
@@ -876,7 +876,7 @@ def jax_funcify_Solve(op):
...
@@ -876,7 +876,7 @@ def jax_funcify_Solve(op):
@jax_funcify.register
(
Det
)
@jax_funcify.register
(
Det
)
def
jax_funcify_Det
(
op
):
def
jax_funcify_Det
(
op
,
**
kwargs
):
def
det
(
x
):
def
det
(
x
):
return
jnp
.
linalg
.
det
(
x
)
return
jnp
.
linalg
.
det
(
x
)
...
@@ -884,7 +884,7 @@ def jax_funcify_Det(op):
...
@@ -884,7 +884,7 @@ def jax_funcify_Det(op):
@jax_funcify.register
(
Eig
)
@jax_funcify.register
(
Eig
)
def
jax_funcify_Eig
(
op
):
def
jax_funcify_Eig
(
op
,
**
kwargs
):
def
eig
(
x
):
def
eig
(
x
):
return
jnp
.
linalg
.
eig
(
x
)
return
jnp
.
linalg
.
eig
(
x
)
...
@@ -892,7 +892,7 @@ def jax_funcify_Eig(op):
...
@@ -892,7 +892,7 @@ def jax_funcify_Eig(op):
@jax_funcify.register
(
Eigh
)
@jax_funcify.register
(
Eigh
)
def
jax_funcify_Eigh
(
op
):
def
jax_funcify_Eigh
(
op
,
**
kwargs
):
uplo
=
op
.
UPLO
uplo
=
op
.
UPLO
def
eigh
(
x
,
uplo
=
uplo
):
def
eigh
(
x
,
uplo
=
uplo
):
...
@@ -902,7 +902,7 @@ def jax_funcify_Eigh(op):
...
@@ -902,7 +902,7 @@ def jax_funcify_Eigh(op):
@jax_funcify.register
(
MatrixInverse
)
@jax_funcify.register
(
MatrixInverse
)
def
jax_funcify_MatrixInverse
(
op
):
def
jax_funcify_MatrixInverse
(
op
,
**
kwargs
):
def
matrix_inverse
(
x
):
def
matrix_inverse
(
x
):
return
jnp
.
linalg
.
inv
(
x
)
return
jnp
.
linalg
.
inv
(
x
)
...
@@ -910,7 +910,7 @@ def jax_funcify_MatrixInverse(op):
...
@@ -910,7 +910,7 @@ def jax_funcify_MatrixInverse(op):
@jax_funcify.register
(
QRFull
)
@jax_funcify.register
(
QRFull
)
def
jax_funcify_QRFull
(
op
):
def
jax_funcify_QRFull
(
op
,
**
kwargs
):
mode
=
op
.
mode
mode
=
op
.
mode
def
qr_full
(
x
,
mode
=
mode
):
def
qr_full
(
x
,
mode
=
mode
):
...
@@ -920,7 +920,7 @@ def jax_funcify_QRFull(op):
...
@@ -920,7 +920,7 @@ def jax_funcify_QRFull(op):
@jax_funcify.register
(
QRIncomplete
)
@jax_funcify.register
(
QRIncomplete
)
def
jax_funcify_QRIncomplete
(
op
):
def
jax_funcify_QRIncomplete
(
op
,
**
kwargs
):
mode
=
op
.
mode
mode
=
op
.
mode
def
qr_incomplete
(
x
,
mode
=
mode
):
def
qr_incomplete
(
x
,
mode
=
mode
):
...
@@ -930,7 +930,7 @@ def jax_funcify_QRIncomplete(op):
...
@@ -930,7 +930,7 @@ def jax_funcify_QRIncomplete(op):
@jax_funcify.register
(
SVD
)
@jax_funcify.register
(
SVD
)
def
jax_funcify_SVD
(
op
):
def
jax_funcify_SVD
(
op
,
**
kwargs
):
full_matrices
=
op
.
full_matrices
full_matrices
=
op
.
full_matrices
compute_uv
=
op
.
compute_uv
compute_uv
=
op
.
compute_uv
...
@@ -941,7 +941,7 @@ def jax_funcify_SVD(op):
...
@@ -941,7 +941,7 @@ def jax_funcify_SVD(op):
@jax_funcify.register
(
CumOp
)
@jax_funcify.register
(
CumOp
)
def
jax_funcify_CumOp
(
op
):
def
jax_funcify_CumOp
(
op
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
mode
=
op
.
mode
mode
=
op
.
mode
...
@@ -955,7 +955,7 @@ def jax_funcify_CumOp(op):
...
@@ -955,7 +955,7 @@ def jax_funcify_CumOp(op):
@jax_funcify.register
(
DiffOp
)
@jax_funcify.register
(
DiffOp
)
def
jax_funcify_DiffOp
(
op
):
def
jax_funcify_DiffOp
(
op
,
**
kwargs
):
n
=
op
.
n
n
=
op
.
n
axis
=
op
.
axis
axis
=
op
.
axis
...
@@ -966,7 +966,7 @@ def jax_funcify_DiffOp(op):
...
@@ -966,7 +966,7 @@ def jax_funcify_DiffOp(op):
@jax_funcify.register
(
RepeatOp
)
@jax_funcify.register
(
RepeatOp
)
def
jax_funcify_RepeatOp
(
op
):
def
jax_funcify_RepeatOp
(
op
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
def
repeatop
(
x
,
repeats
,
axis
=
axis
):
def
repeatop
(
x
,
repeats
,
axis
=
axis
):
...
@@ -976,7 +976,7 @@ def jax_funcify_RepeatOp(op):
...
@@ -976,7 +976,7 @@ def jax_funcify_RepeatOp(op):
@jax_funcify.register
(
Bartlett
)
@jax_funcify.register
(
Bartlett
)
def
jax_funcify_Bartlett
(
op
):
def
jax_funcify_Bartlett
(
op
,
**
kwargs
):
def
bartlett
(
x
):
def
bartlett
(
x
):
return
jnp
.
bartlett
(
x
)
return
jnp
.
bartlett
(
x
)
...
@@ -984,7 +984,7 @@ def jax_funcify_Bartlett(op):
...
@@ -984,7 +984,7 @@ def jax_funcify_Bartlett(op):
@jax_funcify.register
(
FillDiagonal
)
@jax_funcify.register
(
FillDiagonal
)
def
jax_funcify_FillDiagonal
(
op
):
def
jax_funcify_FillDiagonal
(
op
,
**
kwargs
):
# def filldiagonal(a, val):
# def filldiagonal(a, val):
# if a.ndim == 2:
# if a.ndim == 2:
...
@@ -1002,7 +1002,7 @@ def jax_funcify_FillDiagonal(op):
...
@@ -1002,7 +1002,7 @@ def jax_funcify_FillDiagonal(op):
@jax_funcify.register
(
FillDiagonalOffset
)
@jax_funcify.register
(
FillDiagonalOffset
)
def
jax_funcify_FillDiagonalOffset
(
op
):
def
jax_funcify_FillDiagonalOffset
(
op
,
**
kwargs
):
# def filldiagonaloffset(a, val, offset):
# def filldiagonaloffset(a, val, offset):
# height, width = a.shape
# height, width = a.shape
...
@@ -1026,7 +1026,7 @@ def jax_funcify_FillDiagonalOffset(op):
...
@@ -1026,7 +1026,7 @@ def jax_funcify_FillDiagonalOffset(op):
@jax_funcify.register
(
Unique
)
@jax_funcify.register
(
Unique
)
def
jax_funcify_Unique
(
op
):
def
jax_funcify_Unique
(
op
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
if
axis
is
not
None
:
if
axis
is
not
None
:
...
@@ -1055,7 +1055,7 @@ def jax_funcify_Unique(op):
...
@@ -1055,7 +1055,7 @@ def jax_funcify_Unique(op):
@jax_funcify.register
(
UnravelIndex
)
@jax_funcify.register
(
UnravelIndex
)
def
jax_funcify_UnravelIndex
(
op
):
def
jax_funcify_UnravelIndex
(
op
,
**
kwargs
):
order
=
op
.
order
order
=
op
.
order
warn
(
"JAX ignores the `order` parameter in `unravel_index`."
)
warn
(
"JAX ignores the `order` parameter in `unravel_index`."
)
...
@@ -1067,7 +1067,7 @@ def jax_funcify_UnravelIndex(op):
...
@@ -1067,7 +1067,7 @@ def jax_funcify_UnravelIndex(op):
@jax_funcify.register
(
RavelMultiIndex
)
@jax_funcify.register
(
RavelMultiIndex
)
def
jax_funcify_RavelMultiIndex
(
op
):
def
jax_funcify_RavelMultiIndex
(
op
,
**
kwargs
):
mode
=
op
.
mode
mode
=
op
.
mode
order
=
op
.
order
order
=
op
.
order
...
@@ -1079,7 +1079,7 @@ def jax_funcify_RavelMultiIndex(op):
...
@@ -1079,7 +1079,7 @@ def jax_funcify_RavelMultiIndex(op):
@jax_funcify.register
(
Eye
)
@jax_funcify.register
(
Eye
)
def
jax_funcify_Eye
(
op
):
def
jax_funcify_Eye
(
op
,
**
kwargs
):
dtype
=
op
.
dtype
dtype
=
op
.
dtype
def
eye
(
N
,
M
,
k
):
def
eye
(
N
,
M
,
k
):
...
@@ -1089,7 +1089,7 @@ def jax_funcify_Eye(op):
...
@@ -1089,7 +1089,7 @@ def jax_funcify_Eye(op):
@jax_funcify.register
(
BatchedDot
)
@jax_funcify.register
(
BatchedDot
)
def
jax_funcify_BatchedDot
(
op
):
def
jax_funcify_BatchedDot
(
op
,
**
kwargs
):
def
batched_dot
(
a
,
b
):
def
batched_dot
(
a
,
b
):
if
a
.
shape
[
0
]
!=
b
.
shape
[
0
]:
if
a
.
shape
[
0
]
!=
b
.
shape
[
0
]:
raise
TypeError
(
"Shapes must match in the 0-th dimension"
)
raise
TypeError
(
"Shapes must match in the 0-th dimension"
)
...
@@ -1101,7 +1101,7 @@ def jax_funcify_BatchedDot(op):
...
@@ -1101,7 +1101,7 @@ def jax_funcify_BatchedDot(op):
@jax_funcify.register
(
RandomVariable
)
@jax_funcify.register
(
RandomVariable
)
def
jax_funcify_RandomVariable
(
op
):
def
jax_funcify_RandomVariable
(
op
,
**
kwargs
):
name
=
op
.
name
name
=
op
.
name
if
not
hasattr
(
jax
.
random
,
name
):
if
not
hasattr
(
jax
.
random
,
name
):
...
...
tests/link/test_jax.py
浏览文件 @
e6914913
...
@@ -298,22 +298,32 @@ def test_jax_basic():
...
@@ -298,22 +298,32 @@ def test_jax_basic():
)
)
def
test_jax_Composite
():
@pytest.mark.parametrize
(
"x, y, x_val, y_val"
,
[
(
scalar
(
"x"
),
scalar
(
"y"
),
np
.
array
(
10
),
np
.
array
(
20
)),
(
scalar
(
"x"
),
vector
(
"y"
),
np
.
array
(
10
),
np
.
arange
(
10
,
20
)),
(
matrix
(
"x"
),
vector
(
"y"
),
np
.
arange
(
10
*
20
)
.
reshape
((
20
,
10
)),
np
.
arange
(
10
,
20
),
),
],
)
def
test_jax_Composite
(
x
,
y
,
x_val
,
y_val
):
x_s
=
aes
.
float64
(
"x"
)
x_s
=
aes
.
float64
(
"x"
)
y_s
=
aes
.
float64
(
"y"
)
y_s
=
aes
.
float64
(
"y"
)
comp_op
=
Elemwise
(
Composite
([
x_s
,
y_s
],
[
x_s
+
y_s
*
2
]))
comp_op
=
Elemwise
(
Composite
([
x_s
,
y_s
],
[
x_s
+
y_s
*
2
+
aes
.
exp
(
x_s
-
y_s
)]))
x
=
vector
(
"x"
)
y
=
vector
(
"y"
)
out
=
comp_op
(
x
,
y
)
out
=
comp_op
(
x
,
y
)
out_fg
=
FunctionGraph
([
x
,
y
],
[
out
])
out_fg
=
FunctionGraph
([
x
,
y
],
[
out
])
test_input_vals
=
[
test_input_vals
=
[
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
x_val
.
astype
(
config
.
floatX
),
np
.
arange
(
10
,
20
)
.
astype
(
config
.
floatX
),
y_val
.
astype
(
config
.
floatX
),
]
]
_
=
compare_jax_and_py
(
out_fg
,
test_input_vals
)
_
=
compare_jax_and_py
(
out_fg
,
test_input_vals
)
...
@@ -354,7 +364,7 @@ def test_jax_FunctionGraph_once():
...
@@ -354,7 +364,7 @@ def test_jax_FunctionGraph_once():
outputs
[
i
][
0
]
=
inp
[
0
]
outputs
[
i
][
0
]
=
inp
[
0
]
@jax_funcify.register
(
TestOp
)
@jax_funcify.register
(
TestOp
)
def
jax_funcify_TestOp
(
op
):
def
jax_funcify_TestOp
(
op
,
**
kwargs
):
def
func
(
*
args
,
op
=
op
):
def
func
(
*
args
,
op
=
op
):
op
.
called
+=
1
op
.
called
+=
1
return
list
(
args
)
return
list
(
args
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论