Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d6113953
提交
d6113953
authored
8月 25, 2023
作者:
Ricardo Vieira
提交者:
Thomas Wiecki
9月 06, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement vectorize utility
上级
a3eed0b4
隐藏空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
111 行增加
和
55 行删除
+111
-55
__init__.py
pytensor/graph/__init__.py
+1
-1
replace.py
pytensor/graph/replace.py
+65
-2
loop.py
pytensor/scalar/loop.py
+2
-1
blockwise.py
pytensor/tensor/blockwise.py
+14
-44
elemwise.py
pytensor/tensor/elemwise.py
+2
-1
op.py
pytensor/tensor/random/op.py
+1
-1
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+2
-1
test_replace.py
tests/graph/test_replace.py
+20
-1
test_op.py
tests/tensor/random/test_op.py
+1
-1
test_blockwise.py
tests/tensor/test_blockwise.py
+2
-1
test_elemwise.py
tests/tensor/test_elemwise.py
+1
-1
没有找到文件。
pytensor/graph/__init__.py
浏览文件 @
d6113953
...
...
@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
clone
,
ancestors
,
)
from
pytensor.graph.replace
import
clone_replace
,
graph_replace
from
pytensor.graph.replace
import
clone_replace
,
graph_replace
,
vectorize
from
pytensor.graph.op
import
Op
from
pytensor.graph.type
import
Type
from
pytensor.graph.fg
import
FunctionGraph
...
...
pytensor/graph/replace.py
浏览文件 @
d6113953
from
functools
import
partial
from
typing
import
Iterable
,
Optional
,
Sequence
,
Union
,
cast
,
overload
from
functools
import
partial
,
singledispatch
from
typing
import
Iterable
,
Mapping
,
Optional
,
Sequence
,
Union
,
cast
,
overload
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
truncated_graph_inputs
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
ReplaceTypes
=
Union
[
Iterable
[
tuple
[
Variable
,
Variable
]],
dict
[
Variable
,
Variable
]]
...
...
@@ -198,3 +199,65 @@ def graph_replace(
return
list
(
fg
.
outputs
)
else
:
return
fg
.
outputs
[
0
]
@singledispatch
def
_vectorize_node
(
op
:
Op
,
node
:
Apply
,
*
bached_inputs
)
->
Apply
:
# Default implementation is provided in pytensor.tensor.blockwise
raise
NotImplementedError
def
vectorize_node
(
node
:
Apply
,
*
batched_inputs
)
->
Apply
:
"""Returns vectorized version of node with new batched inputs."""
op
=
node
.
op
return
_vectorize_node
(
op
,
node
,
*
batched_inputs
)
def
vectorize
(
outputs
:
Sequence
[
Variable
],
vectorize
:
Mapping
[
Variable
,
Variable
]
)
->
Sequence
[
Variable
]:
"""Vectorize outputs graph given mapping from old variables to expanded counterparts version.
Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
Examples
--------
.. code-block:: python
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize
# Original graph
x = pt.vector("x")
y = pt.exp(x) / pt.sum(pt.exp(x))
# Vectorized graph
new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x})
fn = pytensor.function([new_x], new_y)
fn([[0, 1, 2], [2, 1, 0]])
# array([[0.09003057, 0.24472847, 0.66524096],
# [0.66524096, 0.24472847, 0.09003057]])
"""
# Avoid circular import
inputs
=
truncated_graph_inputs
(
outputs
,
ancestors_to_include
=
vectorize
.
keys
())
new_inputs
=
[
vectorize
.
get
(
inp
,
inp
)
for
inp
in
inputs
]
def
transform
(
var
):
if
var
in
inputs
:
return
new_inputs
[
inputs
.
index
(
var
)]
node
=
var
.
owner
batched_inputs
=
[
transform
(
inp
)
for
inp
in
node
.
inputs
]
batched_node
=
vectorize_node
(
node
,
*
batched_inputs
)
batched_var
=
batched_node
.
outputs
[
var
.
owner
.
outputs
.
index
(
var
)]
return
batched_var
# TODO: MergeOptimization or node caching?
return
[
transform
(
out
)
for
out
in
outputs
]
pytensor/scalar/loop.py
浏览文件 @
d6113953
...
...
@@ -2,7 +2,8 @@ from itertools import chain
from
typing
import
Optional
,
Sequence
,
Tuple
from
pytensor.compile
import
rebuild_collect_shared
from
pytensor.graph
import
Constant
,
FunctionGraph
,
Variable
,
clone
from
pytensor.graph.basic
import
Constant
,
Variable
,
clone
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scalar.basic
import
ScalarInnerGraphOp
,
as_scalar
...
...
pytensor/tensor/blockwise.py
浏览文件 @
d6113953
import
re
from
functools
import
singledispatch
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
cast
import
numpy
as
np
...
...
@@ -9,6 +8,7 @@ from pytensor.gradient import DisconnectedType
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.null_type
import
NullType
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
,
vectorize
from
pytensor.tensor
import
as_tensor_variable
from
pytensor.tensor.shape
import
shape_padleft
from
pytensor.tensor.type
import
continuous_dtypes
,
discrete_dtypes
,
tensor
...
...
@@ -72,8 +72,8 @@ def safe_signature(
return
f
"{inputs_sig}->{outputs_sig}"
@
singledispatch
def
_vectorize_node
(
op
:
Op
,
node
:
Apply
,
*
bached_inputs
)
->
Apply
:
@
_vectorize_node.register
(
Op
)
def
vectorize_node_fallback
(
op
:
Op
,
node
:
Apply
,
*
bached_inputs
)
->
Apply
:
if
hasattr
(
op
,
"gufunc_signature"
):
signature
=
op
.
gufunc_signature
else
:
...
...
@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
return
cast
(
Apply
,
Blockwise
(
op
,
signature
=
signature
)
.
make_node
(
*
bached_inputs
))
def
vectorize_node
(
node
:
Apply
,
*
batched_inputs
)
->
Apply
:
"""Returns vectorized version of node with new batched inputs."""
op
=
node
.
op
return
_vectorize_node
(
op
,
node
,
*
batched_inputs
)
class
Blockwise
(
Op
):
"""Generalizes a core `Op` to work with batched dimensions.
...
...
@@ -279,42 +273,18 @@ class Blockwise(Op):
core_igrads
=
self
.
core_op
.
L_op
(
core_inputs
,
core_outputs
,
core_ograds
)
batch_ndims
=
self
.
_batch_ndim_from_outputs
(
outputs
)
def
transform
(
var
):
# From a graph of ScalarOps, make a graph of Broadcast ops.
if
isinstance
(
var
.
type
,
(
NullType
,
DisconnectedType
)):
return
var
if
var
in
core_inputs
:
return
inputs
[
core_inputs
.
index
(
var
)]
if
var
in
core_outputs
:
return
outputs
[
core_outputs
.
index
(
var
)]
if
var
in
core_ograds
:
return
ograds
[
core_ograds
.
index
(
var
)]
node
=
var
.
owner
# The gradient contains a constant, which may be responsible for broadcasting
if
node
is
None
:
if
batch_ndims
:
var
=
shape_padleft
(
var
,
batch_ndims
)
return
var
batched_inputs
=
[
transform
(
inp
)
for
inp
in
node
.
inputs
]
batched_node
=
vectorize_node
(
node
,
*
batched_inputs
)
batched_var
=
batched_node
.
outputs
[
var
.
owner
.
outputs
.
index
(
var
)]
return
batched_var
ret
=
[]
for
core_igrad
,
ipt
in
zip
(
core_igrads
,
inputs
):
# Undefined gradient
if
core_igrad
is
None
:
ret
.
append
(
None
)
else
:
ret
.
append
(
transform
(
core_igrad
))
igrads
=
vectorize
(
[
core_igrad
for
core_igrad
in
core_igrads
if
core_igrad
is
not
None
],
vectorize
=
dict
(
zip
(
core_inputs
+
core_outputs
+
core_ograds
,
inputs
+
outputs
+
ograds
)
),
)
return
ret
igrads_iter
=
iter
(
igrads
)
return
[
None
if
core_igrad
is
None
else
next
(
igrads_iter
)
for
core_igrad
in
core_igrads
]
def
L_op
(
self
,
inputs
,
outs
,
ograds
):
from
pytensor.tensor.math
import
sum
as
pt_sum
...
...
pytensor/tensor/elemwise.py
浏览文件 @
d6113953
...
...
@@ -8,6 +8,7 @@ from pytensor.configdefaults import config
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.null_type
import
NullType
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.utils
import
MethodNotDefined
from
pytensor.link.c.basic
import
failure_code
from
pytensor.link.c.op
import
COp
,
ExternalCOp
,
OpenMPOp
...
...
@@ -22,7 +23,7 @@ from pytensor.scalar.basic import transfer_type, upcast
from
pytensor.tensor
import
elemwise_cgen
as
cgen
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor.basic
import
_get_vector_length
,
as_tensor_variable
from
pytensor.tensor.blockwise
import
_vectorize_node
,
vectorize_not_needed
from
pytensor.tensor.blockwise
import
vectorize_not_needed
from
pytensor.tensor.type
import
(
TensorType
,
continuous_dtypes
,
...
...
pytensor/tensor/random/op.py
浏览文件 @
d6113953
...
...
@@ -7,6 +7,7 @@ import pytensor
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
Variable
,
equal_computations
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.scalar
import
ScalarVariable
from
pytensor.tensor.basic
import
(
...
...
@@ -17,7 +18,6 @@ from pytensor.tensor.basic import (
get_vector_length
,
infer_static_shape
,
)
from
pytensor.tensor.blockwise
import
_vectorize_node
from
pytensor.tensor.random.type
import
RandomGeneratorType
,
RandomStateType
,
RandomType
from
pytensor.tensor.random.utils
import
(
broadcast_params
,
...
...
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
d6113953
from
pytensor.compile.mode
import
optdb
from
pytensor.graph
import
node_rewriter
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node
from
pytensor.tensor.blockwise
import
Blockwise
@node_rewriter
([
Blockwise
])
...
...
tests/graph/test_replace.py
浏览文件 @
d6113953
import
numpy
as
np
import
pytest
import
scipy.special
import
pytensor.tensor
as
pt
from
pytensor
import
config
,
function
,
shared
from
pytensor.graph.basic
import
graph_inputs
from
pytensor.graph.replace
import
clone_replace
,
graph_replace
from
pytensor.graph.replace
import
clone_replace
,
graph_replace
,
vectorize
from
pytensor.tensor
import
dvector
,
fvector
,
vector
from
tests
import
unittest_tools
as
utt
from
tests.graph.utils
import
MyOp
,
MyVariable
...
...
@@ -223,3 +224,21 @@ class TestGraphReplace:
assert
oc
[
0
]
is
o
with
pytest
.
raises
(
ValueError
,
match
=
"Some replacements were not used"
):
oc
=
graph_replace
([
o
],
{
fake
:
x
.
clone
()},
strict
=
True
)
class
TestVectorize
:
# TODO: Add tests with multiple outputs, constants, and other singleton types
def
test_basic
(
self
):
x
=
pt
.
vector
(
"x"
)
y
=
pt
.
exp
(
x
)
/
pt
.
sum
(
pt
.
exp
(
x
))
new_x
=
pt
.
matrix
(
"new_x"
)
[
new_y
]
=
vectorize
([
y
],
{
x
:
new_x
})
fn
=
function
([
new_x
],
new_y
)
test_new_y
=
np
.
array
([[
0
,
1
,
2
],
[
2
,
1
,
0
]])
.
astype
(
config
.
floatX
)
np
.
testing
.
assert_allclose
(
fn
(
test_new_y
),
scipy
.
special
.
softmax
(
test_new_y
,
axis
=-
1
),
)
tests/tensor/random/test_op.py
浏览文件 @
d6113953
...
...
@@ -4,8 +4,8 @@ import pytest
import
pytensor.tensor
as
at
from
pytensor
import
config
,
function
from
pytensor.gradient
import
NullTypeGradError
,
grad
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.raise_op
import
Assert
from
pytensor.tensor.blockwise
import
vectorize_node
from
pytensor.tensor.math
import
eq
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random.op
import
RandomState
,
RandomVariable
,
default_rng
...
...
tests/tensor/test_blockwise.py
浏览文件 @
d6113953
...
...
@@ -8,8 +8,9 @@ import pytensor
from
pytensor
import
config
from
pytensor.gradient
import
grad
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.tensor
import
tensor
from
pytensor.tensor.blockwise
import
Blockwise
,
_parse_gufunc_signature
,
vectorize_node
from
pytensor.tensor.blockwise
import
Blockwise
,
_parse_gufunc_signature
from
pytensor.tensor.nlinalg
import
MatrixInverse
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
...
...
tests/tensor/test_elemwise.py
浏览文件 @
d6113953
...
...
@@ -13,11 +13,11 @@ from pytensor.compile.mode import Mode
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.link.basic
import
PerformLinker
from
pytensor.link.c.basic
import
CLinker
,
OpWiseCLinker
from
pytensor.tensor
import
as_tensor_variable
from
pytensor.tensor.basic
import
second
from
pytensor.tensor.blockwise
import
vectorize_node
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.math
import
Any
,
Sum
from
pytensor.tensor.math
import
all
as
pt_all
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论