Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
35f0df96
提交
35f0df96
authored
11月 22, 2023
作者:
Ricardo Vieira
提交者:
Thomas Wiecki
11月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make params exclusive to COp's
Also removes them from the signature of perform
上级
5f840278
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
105 行增加
和
132 行删除
+105
-132
using_params.rst
doc/extending/using_params.rst
+17
-21
basic.py
pytensor/graph/basic.py
+0
-11
op.py
pytensor/graph/op.py
+9
-56
basic.py
pytensor/link/c/basic.py
+11
-2
interface.py
pytensor/link/c/interface.py
+24
-1
op.py
pytensor/link/c/op.py
+2
-2
basic.py
pytensor/link/numba/dispatch/basic.py
+5
-16
raise_op.py
pytensor/raise_op.py
+1
-1
op.py
pytensor/scan/op.py
+1
-1
basic.py
pytensor/tensor/basic.py
+3
-3
blas.py
pytensor/tensor/blas.py
+4
-4
elemwise.py
pytensor/tensor/elemwise.py
+1
-1
extra_ops.py
pytensor/tensor/extra_ops.py
+3
-3
math.py
pytensor/tensor/math.py
+3
-3
uncanonicalize.py
pytensor/tensor/rewriting/uncanonicalize.py
+1
-1
shape.py
pytensor/tensor/shape.py
+2
-2
subtensor.py
pytensor/tensor/subtensor.py
+1
-1
test_params_type.py
tests/link/c/test_params_type.py
+4
-2
test_type.py
tests/link/c/test_type.py
+2
-1
test_blockwise.py
tests/tensor/test_blockwise.py
+11
-0
没有找到文件。
doc/extending/using_params.rst
浏览文件 @
35f0df96
.. _extending_op_params:
.. _extending_op_params:
===============
===============
=
Using Op params
Using
C
Op params
===============
===============
=
The Op params is a facility to pass some runtime parameters to the
The
C
Op params is a facility to pass some runtime parameters to the
code of an op without modifying it. It can enable a single instance
code of an op without modifying it. It can enable a single instance
of C code to serve different needs and therefore reduce compilation.
of C code to serve different needs and therefore reduce compilation.
...
@@ -53,7 +53,7 @@ following methods will be used for the type:
...
@@ -53,7 +53,7 @@ following methods will be used for the type:
- :meth:`__hash__ <Type.__hash__>`
- :meth:`__hash__ <Type.__hash__>`
- :meth:`values_eq <Type.values_eq>`
- :meth:`values_eq <Type.values_eq>`
Additionally
if you want
to use your params with C code, you need to extend `COp`
Additionally
,
to use your params with C code, you need to extend `COp`
and implement the following methods:
and implement the following methods:
- :meth:`c_declare <CLinkerType.c_declare>`
- :meth:`c_declare <CLinkerType.c_declare>`
...
@@ -65,24 +65,24 @@ You can also define other convenience methods such as
...
@@ -65,24 +65,24 @@ You can also define other convenience methods such as
:meth:`c_headers <CLinkerType.c_headers>` if you need any special things.
:meth:`c_headers <CLinkerType.c_headers>` if you need any special things.
Registering the params with your Op
Registering the params with your
C
Op
-----------------------------------
-----------------------------------
-
To declare that your
Op
uses params you have to set the class
To declare that your
`COp`
uses params you have to set the class
attribute :attr:`params_type` to an instance of your params Type.
attribute :attr:`params_type` to an instance of your params Type.
.. note::
.. note::
If you want to have multiple parameters, PyTensor provides the convenient class
If you want to have multiple parameters, PyTensor provides the convenient class
:class:`pytensor.link.c.params_type.ParamsType` that allows to bundle many parameters into
:class:`pytensor.link.c.params_type.ParamsType` that allows to bundle many parameters into
one object that will be available
in both Python (as a Python object) and
C code (as a struct).
one object that will be available
to the
C code (as a struct).
For example if we decide to use an int as the params the following
For example if we decide to use an int as the params the following
would be appropriate:
would be appropriate:
.. code-block:: python
.. code-block:: python
class MyOp(Op):
class MyOp(
C
Op):
params_type = Generic()
params_type = Generic()
After that you need to define a :meth:`get_params` method on your
After that you need to define a :meth:`get_params` method on your
...
@@ -115,12 +115,7 @@ Having declared a params for your Op will affect the expected
...
@@ -115,12 +115,7 @@ Having declared a params for your Op will affect the expected
signature of :meth:`perform`. The new expected signature will have an
signature of :meth:`perform`. The new expected signature will have an
extra parameter at the end which corresponds to the params object.
extra parameter at the end which corresponds to the params object.
.. warning::
The `sub` dictionary for `COp`s with params will contain an extra entry
If you do not account for this extra parameter, the code will fail
at runtime if it tries to run the python version.
Also, for the C code, the `sub` dictionary will contain an extra entry
`'params'` which will map to the variable name of the params object.
`'params'` which will map to the variable name of the params object.
This is true for all methods that receive a `sub` parameter, so this
This is true for all methods that receive a `sub` parameter, so this
means that you can use your params in the :meth:`c_code <COp.c_code>`
means that you can use your params in the :meth:`c_code <COp.c_code>`
...
@@ -131,7 +126,7 @@ A simple example
...
@@ -131,7 +126,7 @@ A simple example
----------------
----------------
This is a simple example which uses a params object to pass a value.
This is a simple example which uses a params object to pass a value.
This `Op` will multiply a scalar input by a fixed floating point value.
This `
C
Op` will multiply a scalar input by a fixed floating point value.
Since the value in this case is a python float, we chose Generic as
Since the value in this case is a python float, we chose Generic as
the params type.
the params type.
...
@@ -156,9 +151,10 @@ the params type.
...
@@ -156,9 +151,10 @@ the params type.
inp = as_scalar(inp)
inp = as_scalar(inp)
return Apply(self, [inp], [inp.type()])
return Apply(self, [inp], [inp.type()])
def perform(self, node, inputs, output_storage, params):
def perform(self, node, inputs, output_storage):
# Here params is a python float so this is ok
# Because params is a python float we can use `self.mul` directly.
output_storage[0][0] = inputs[0] * params
# If it's something fancier, call `self.params_type.filter(self.get_params(node))`
output_storage[0][0] = inputs[0] * self.mul
def c_code(self, node, name, inputs, outputs, sub):
def c_code(self, node, name, inputs, outputs, sub):
return ("%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);" %
return ("%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);" %
...
@@ -174,7 +170,7 @@ weights.
...
@@ -174,7 +170,7 @@ weights.
.. testcode::
.. testcode::
from pytensor.
graph.op import
Op
from pytensor.
link.c.op import C
Op
from pytensor.link.c.type import Generic
from pytensor.link.c.type import Generic
from pytensor.scalar import as_scalar
from pytensor.scalar import as_scalar
...
...
pytensor/graph/basic.py
浏览文件 @
35f0df96
...
@@ -30,7 +30,6 @@ import numpy as np
...
@@ -30,7 +30,6 @@ import numpy as np
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.utils
import
(
from
pytensor.graph.utils
import
(
MetaObject
,
MetaObject
,
MethodNotDefined
,
Scratchpad
,
Scratchpad
,
TestValueError
,
TestValueError
,
ValidatingScratchpad
,
ValidatingScratchpad
,
...
@@ -151,16 +150,6 @@ class Apply(Node, Generic[OpType]):
...
@@ -151,16 +150,6 @@ class Apply(Node, Generic[OpType]):
f
"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
f
"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
)
)
def
run_params
(
self
):
"""
Returns the params for the node, or NoParams if no params is set.
"""
try
:
return
self
.
op
.
get_params
(
self
)
except
MethodNotDefined
:
return
NoParams
def
__getstate__
(
self
):
def
__getstate__
(
self
):
d
=
self
.
__dict__
d
=
self
.
__dict__
# ufunc don't pickle/unpickle well
# ufunc don't pickle/unpickle well
...
...
pytensor/graph/op.py
浏览文件 @
35f0df96
...
@@ -16,15 +16,13 @@ from typing import (
...
@@ -16,15 +16,13 @@ from typing import (
import
pytensor
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
NoParams
,
Variable
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.utils
import
(
from
pytensor.graph.utils
import
(
MetaObject
,
MetaObject
,
MethodNotDefined
,
TestValueError
,
TestValueError
,
add_tag_trace
,
add_tag_trace
,
get_variable_trace_string
,
get_variable_trace_string
,
)
)
from
pytensor.link.c.params_type
import
Params
,
ParamsType
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -37,10 +35,7 @@ StorageMapType = dict[Variable, StorageCellType]
...
@@ -37,10 +35,7 @@ StorageMapType = dict[Variable, StorageCellType]
ComputeMapType
=
dict
[
Variable
,
list
[
bool
]]
ComputeMapType
=
dict
[
Variable
,
list
[
bool
]]
InputStorageType
=
list
[
StorageCellType
]
InputStorageType
=
list
[
StorageCellType
]
OutputStorageType
=
list
[
StorageCellType
]
OutputStorageType
=
list
[
StorageCellType
]
ParamsInputType
=
Optional
[
tuple
[
Any
,
...
]]
PerformMethodType
=
Callable
[[
Apply
,
list
[
Any
],
OutputStorageType
],
None
]
PerformMethodType
=
Callable
[
[
Apply
,
list
[
Any
],
OutputStorageType
,
ParamsInputType
],
None
]
BasicThunkType
=
Callable
[[],
None
]
BasicThunkType
=
Callable
[[],
None
]
ThunkCallableType
=
Callable
[
ThunkCallableType
=
Callable
[
[
PerformMethodType
,
StorageMapType
,
ComputeMapType
,
Apply
],
None
[
PerformMethodType
,
StorageMapType
,
ComputeMapType
,
Apply
],
None
...
@@ -202,7 +197,6 @@ class Op(MetaObject):
...
@@ -202,7 +197,6 @@ class Op(MetaObject):
itypes
:
Optional
[
Sequence
[
"Type"
]]
=
None
itypes
:
Optional
[
Sequence
[
"Type"
]]
=
None
otypes
:
Optional
[
Sequence
[
"Type"
]]
=
None
otypes
:
Optional
[
Sequence
[
"Type"
]]
=
None
params_type
:
Optional
[
ParamsType
]
=
None
_output_type_depends_on_input_value
=
False
_output_type_depends_on_input_value
=
False
"""
"""
...
@@ -426,7 +420,6 @@ class Op(MetaObject):
...
@@ -426,7 +420,6 @@ class Op(MetaObject):
node
:
Apply
,
node
:
Apply
,
inputs
:
Sequence
[
Any
],
inputs
:
Sequence
[
Any
],
output_storage
:
OutputStorageType
,
output_storage
:
OutputStorageType
,
params
:
ParamsInputType
=
None
,
)
->
None
:
)
->
None
:
"""Calculate the function on the inputs and put the variables in the output storage.
"""Calculate the function on the inputs and put the variables in the output storage.
...
@@ -442,8 +435,6 @@ class Op(MetaObject):
...
@@ -442,8 +435,6 @@ class Op(MetaObject):
these lists). Each sub-list corresponds to value of each
these lists). Each sub-list corresponds to value of each
`Variable` in :attr:`node.outputs`. The primary purpose of this method
`Variable` in :attr:`node.outputs`. The primary purpose of this method
is to set the values of these sub-lists.
is to set the values of these sub-lists.
params
A tuple containing the values of each entry in :attr:`Op.__props__`.
Notes
Notes
-----
-----
...
@@ -481,22 +472,6 @@ class Op(MetaObject):
...
@@ -481,22 +472,6 @@ class Op(MetaObject):
"""
"""
return
True
return
True
def
get_params
(
self
,
node
:
Apply
)
->
Params
:
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if
isinstance
(
self
.
params_type
,
ParamsType
):
wrapper
=
self
.
params_type
if
not
all
(
hasattr
(
self
,
field
)
for
field
in
wrapper
.
fields
):
# Let's print missing attributes for debugging.
not_found
=
tuple
(
field
for
field
in
wrapper
.
fields
if
not
hasattr
(
self
,
field
)
)
raise
AttributeError
(
f
"{type(self).__name__}: missing attributes {not_found} for ParamsType."
)
# ParamsType.get_params() will apply filtering to attributes.
return
self
.
params_type
.
get_params
(
self
)
raise
MethodNotDefined
(
"get_params"
)
def
prepare_node
(
def
prepare_node
(
self
,
self
,
node
:
Apply
,
node
:
Apply
,
...
@@ -538,34 +513,12 @@ class Op(MetaObject):
...
@@ -538,34 +513,12 @@ class Op(MetaObject):
else
:
else
:
p
=
node
.
op
.
perform
p
=
node
.
op
.
perform
params
=
node
.
run_params
()
@is_thunk_type
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
):
if
params
is
NoParams
:
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
# default arguments are stored in the closure of `rval`
for
o
in
node
.
outputs
:
@is_thunk_type
compute_map
[
o
][
0
]
=
True
def
rval
(
return
r
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
,
params
=
None
):
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
for
o
in
node
.
outputs
:
compute_map
[
o
][
0
]
=
True
return
r
else
:
params_val
=
node
.
params_type
.
filter
(
params
)
@is_thunk_type
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
,
params
=
params_val
,
):
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
,
params
)
for
o
in
node
.
outputs
:
compute_map
[
o
][
0
]
=
True
return
r
rval
.
inputs
=
node_input_storage
rval
.
inputs
=
node_input_storage
rval
.
outputs
=
node_output_storage
rval
.
outputs
=
node_output_storage
...
@@ -640,7 +593,7 @@ class _NoPythonOp(Op):
...
@@ -640,7 +593,7 @@ class _NoPythonOp(Op):
"""
"""
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
raise
NotImplementedError
(
"No Python implementation is provided by this Op."
)
raise
NotImplementedError
(
"No Python implementation is provided by this Op."
)
...
...
pytensor/link/c/basic.py
浏览文件 @
35f0df96
...
@@ -20,6 +20,7 @@ from pytensor.graph.basic import (
...
@@ -20,6 +20,7 @@ from pytensor.graph.basic import (
io_toposort
,
io_toposort
,
vars_between
,
vars_between
,
)
)
from
pytensor.graph.utils
import
MethodNotDefined
from
pytensor.link.basic
import
Container
,
Linker
,
LocalLinker
,
PerformLinker
from
pytensor.link.basic
import
Container
,
Linker
,
LocalLinker
,
PerformLinker
from
pytensor.link.c.cmodule
import
(
from
pytensor.link.c.cmodule
import
(
METH_VARARGS
,
METH_VARARGS
,
...
@@ -617,7 +618,12 @@ class CLinker(Linker):
...
@@ -617,7 +618,12 @@ class CLinker(Linker):
# that needs it
# that needs it
self
.
node_params
=
dict
()
self
.
node_params
=
dict
()
for
node
in
self
.
node_order
:
for
node
in
self
.
node_order
:
params
=
node
.
run_params
()
if
not
isinstance
(
node
.
op
,
CLinkerOp
):
continue
try
:
params
=
node
.
op
.
get_params
(
node
)
except
MethodNotDefined
:
params
=
NoParams
if
params
is
not
NoParams
:
if
params
is
not
NoParams
:
# try to avoid creating more than one variable for the
# try to avoid creating more than one variable for the
# same params.
# same params.
...
@@ -803,7 +809,10 @@ class CLinker(Linker):
...
@@ -803,7 +809,10 @@ class CLinker(Linker):
sub
=
dict
(
failure_var
=
failure_var
)
sub
=
dict
(
failure_var
=
failure_var
)
params
=
node
.
run_params
()
try
:
params
=
op
.
get_params
(
node
)
except
MethodNotDefined
:
params
=
NoParams
if
params
is
not
NoParams
:
if
params
is
not
NoParams
:
params_var
=
symbol
[
self
.
node_params
[
params
]]
params_var
=
symbol
[
self
.
node_params
[
params
]]
...
...
pytensor/link/c/interface.py
浏览文件 @
35f0df96
import
typing
import
warnings
import
warnings
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Callable
from
typing
import
Callable
,
Optional
from
pytensor.graph.basic
import
Apply
,
Constant
from
pytensor.graph.basic
import
Apply
,
Constant
from
pytensor.graph.utils
import
MethodNotDefined
from
pytensor.graph.utils
import
MethodNotDefined
if
typing
.
TYPE_CHECKING
:
from
pytensor.link.c.params_type
import
Params
,
ParamsType
class
CLinkerObject
:
class
CLinkerObject
:
"""Standard methods for an `Op` or `Type` used with the `CLinker`."""
"""Standard methods for an `Op` or `Type` used with the `CLinker`."""
...
@@ -172,6 +177,8 @@ class CLinkerObject:
...
@@ -172,6 +177,8 @@ class CLinkerObject:
class
CLinkerOp
(
CLinkerObject
):
class
CLinkerOp
(
CLinkerObject
):
"""Interface definition for `Op` subclasses compiled by `CLinker`."""
"""Interface definition for `Op` subclasses compiled by `CLinker`."""
params_type
:
Optional
[
"ParamsType"
]
=
None
@abstractmethod
@abstractmethod
def
c_code
(
def
c_code
(
self
,
self
,
...
@@ -362,6 +369,22 @@ class CLinkerOp(CLinkerObject):
...
@@ -362,6 +369,22 @@ class CLinkerOp(CLinkerObject):
"""
"""
return
""
return
""
def
get_params
(
self
,
node
:
Apply
)
->
"Params"
:
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if
self
.
params_type
is
not
None
:
wrapper
=
self
.
params_type
if
not
all
(
hasattr
(
self
,
field
)
for
field
in
wrapper
.
fields
):
# Let's print missing attributes for debugging.
not_found
=
tuple
(
field
for
field
in
wrapper
.
fields
if
not
hasattr
(
self
,
field
)
)
raise
AttributeError
(
f
"{type(self).__name__}: missing attributes {not_found} for ParamsType."
)
# ParamsType.get_params() will apply filtering to attributes.
return
self
.
params_type
.
get_params
(
self
)
raise
MethodNotDefined
(
"get_params"
)
class
CLinkerType
(
CLinkerObject
):
class
CLinkerType
(
CLinkerObject
):
r"""Interface specification for `Type`\s that can be arguments to a `CLinkerOp`.
r"""Interface specification for `Type`\s that can be arguments to a `CLinkerOp`.
...
...
pytensor/link/c/op.py
浏览文件 @
35f0df96
...
@@ -664,7 +664,7 @@ class _NoPythonCOp(COp):
...
@@ -664,7 +664,7 @@ class _NoPythonCOp(COp):
"""
"""
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
raise
NotImplementedError
(
"No Python implementation is provided by this COp."
)
raise
NotImplementedError
(
"No Python implementation is provided by this COp."
)
...
@@ -675,7 +675,7 @@ class _NoPythonExternalCOp(ExternalCOp):
...
@@ -675,7 +675,7 @@ class _NoPythonExternalCOp(ExternalCOp):
"""
"""
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"No Python implementation is provided by this ExternalCOp."
"No Python implementation is provided by this ExternalCOp."
)
)
pytensor/link/numba/dispatch/basic.py
浏览文件 @
35f0df96
...
@@ -21,7 +21,7 @@ from numba.extending import box, overload
...
@@ -21,7 +21,7 @@ from numba.extending import box, overload
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.graph.basic
import
Apply
,
NoParams
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.type
import
Type
from
pytensor.graph.type
import
Type
from
pytensor.ifelse
import
IfElse
from
pytensor.ifelse
import
IfElse
...
@@ -383,22 +383,11 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
...
@@ -383,22 +383,11 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
output_types
=
tuple
(
out
.
type
for
out
in
node
.
outputs
)
output_types
=
tuple
(
out
.
type
for
out
in
node
.
outputs
)
params
=
node
.
run_params
()
if
params
is
not
NoParams
:
def
py_perform
(
inputs
):
params_val
=
dict
(
node
.
params_type
.
filter
(
params
))
outputs
=
[[
None
]
for
i
in
range
(
n_outputs
)]
op
.
perform
(
node
,
inputs
,
outputs
)
def
py_perform
(
inputs
):
return
outputs
outputs
=
[[
None
]
for
i
in
range
(
n_outputs
)]
op
.
perform
(
node
,
inputs
,
outputs
,
params_val
)
return
outputs
else
:
def
py_perform
(
inputs
):
outputs
=
[[
None
]
for
i
in
range
(
n_outputs
)]
op
.
perform
(
node
,
inputs
,
outputs
)
return
outputs
if
n_outputs
==
1
:
if
n_outputs
==
1
:
...
...
pytensor/raise_op.py
浏览文件 @
35f0df96
...
@@ -90,7 +90,7 @@ class CheckAndRaise(COp):
...
@@ -90,7 +90,7 @@ class CheckAndRaise(COp):
[
value
.
type
()],
[
value
.
type
()],
)
)
def
perform
(
self
,
node
,
inputs
,
outputs
,
params
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
(
out
,)
=
outputs
(
out
,)
=
outputs
val
,
*
conds
=
inputs
val
,
*
conds
=
inputs
out
[
0
]
=
val
out
[
0
]
=
val
...
...
pytensor/scan/op.py
浏览文件 @
35f0df96
...
@@ -1658,7 +1658,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1658,7 +1658,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rval
.
lazy
=
False
rval
.
lazy
=
False
return
rval
return
rval
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
"""Compute the scan operation in Python.
"""Compute the scan operation in Python.
The `inputs` are packed like this:
The `inputs` are packed like this:
...
...
pytensor/tensor/basic.py
浏览文件 @
35f0df96
...
@@ -3991,11 +3991,11 @@ class AllocEmpty(COp):
...
@@ -3991,11 +3991,11 @@ class AllocEmpty(COp):
output
.
tag
.
nan_guard_mode_check
=
False
output
.
tag
.
nan_guard_mode_check
=
False
return
Apply
(
self
,
_shape
,
[
output
])
return
Apply
(
self
,
_shape
,
[
output
])
def
debug_perform
(
self
,
node
,
inputs
,
out_
,
params
):
def
debug_perform
(
self
,
node
,
inputs
,
out_
):
self
.
perform
(
node
,
inputs
,
out_
,
params
)
self
.
perform
(
node
,
inputs
,
out_
)
out_
[
0
][
0
]
.
fill
(
-
123456789
)
out_
[
0
][
0
]
.
fill
(
-
123456789
)
def
perform
(
self
,
node
,
inputs
,
out_
,
params
):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
(
out
,)
=
out_
sh
=
tuple
([
int
(
i
)
for
i
in
inputs
])
sh
=
tuple
([
int
(
i
)
for
i
in
inputs
])
if
out
[
0
]
is
None
or
out
[
0
]
.
shape
!=
sh
:
if
out
[
0
]
is
None
or
out
[
0
]
.
shape
!=
sh
:
...
...
pytensor/tensor/blas.py
浏览文件 @
35f0df96
...
@@ -207,7 +207,7 @@ class Gemv(Op):
...
@@ -207,7 +207,7 @@ class Gemv(Op):
return
Apply
(
self
,
inputs
,
[
y
.
type
()])
return
Apply
(
self
,
inputs
,
[
y
.
type
()])
def
perform
(
self
,
node
,
inputs
,
out_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
out_storage
):
y
,
alpha
,
A
,
x
,
beta
=
inputs
y
,
alpha
,
A
,
x
,
beta
=
inputs
if
(
if
(
have_fblas
have_fblas
...
@@ -309,7 +309,7 @@ class Ger(Op):
...
@@ -309,7 +309,7 @@ class Ger(Op):
return
Apply
(
self
,
inputs
,
[
A
.
type
()])
return
Apply
(
self
,
inputs
,
[
A
.
type
()])
def
perform
(
self
,
node
,
inp
,
out
,
params
=
None
):
def
perform
(
self
,
node
,
inp
,
out
):
cA
,
calpha
,
cx
,
cy
=
inp
cA
,
calpha
,
cx
,
cy
=
inp
(
cZ
,)
=
out
(
cZ
,)
=
out
if
self
.
destructive
:
if
self
.
destructive
:
...
@@ -912,12 +912,12 @@ class Gemm(GemmRelated):
...
@@ -912,12 +912,12 @@ class Gemm(GemmRelated):
output
=
z
.
type
()
output
=
z
.
type
()
return
Apply
(
self
,
inputs
,
[
output
])
return
Apply
(
self
,
inputs
,
[
output
])
def
perform
(
self
,
node
,
inp
,
out
,
params
):
def
perform
(
self
,
node
,
inp
,
out
):
z
,
a
,
x
,
y
,
b
=
inp
z
,
a
,
x
,
y
,
b
=
inp
(
zout
,)
=
out
(
zout
,)
=
out
assert
a
.
shape
==
()
assert
a
.
shape
==
()
assert
b
.
shape
==
()
assert
b
.
shape
==
()
if
not
params
.
inplace
:
if
not
self
.
inplace
:
z
=
z
.
copy
()
# the original z will not be changed
z
=
z
.
copy
()
# the original z will not be changed
if
z
.
shape
==
():
if
z
.
shape
==
():
z
.
itemset
(
z
*
a
+
b
*
np
.
dot
(
x
,
y
))
z
.
itemset
(
z
*
a
+
b
*
np
.
dot
(
x
,
y
))
...
...
pytensor/tensor/elemwise.py
浏览文件 @
35f0df96
...
@@ -233,7 +233,7 @@ class DimShuffle(ExternalCOp):
...
@@ -233,7 +233,7 @@ class DimShuffle(ExternalCOp):
return
f
"Transpose{{axes={self.shuffle}}}"
return
f
"Transpose{{axes={self.shuffle}}}"
return
f
"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
return
f
"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
def
perform
(
self
,
node
,
inp
,
out
,
params
=
None
):
def
perform
(
self
,
node
,
inp
,
out
):
(
res
,)
=
inp
(
res
,)
=
inp
(
storage
,)
=
out
(
storage
,)
=
out
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
35f0df96
...
@@ -145,7 +145,7 @@ class SearchsortedOp(COp):
...
@@ -145,7 +145,7 @@ class SearchsortedOp(COp):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
1
]]
return
[
shapes
[
1
]]
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
x
=
inputs
[
0
]
x
=
inputs
[
0
]
v
=
inputs
[
1
]
v
=
inputs
[
1
]
if
len
(
node
.
inputs
)
==
3
:
if
len
(
node
.
inputs
)
==
3
:
...
@@ -154,7 +154,7 @@ class SearchsortedOp(COp):
...
@@ -154,7 +154,7 @@ class SearchsortedOp(COp):
sorter
=
None
sorter
=
None
z
=
output_storage
[
0
]
z
=
output_storage
[
0
]
z
[
0
]
=
np
.
searchsorted
(
x
,
v
,
side
=
params
,
sorter
=
sorter
)
.
astype
(
z
[
0
]
=
np
.
searchsorted
(
x
,
v
,
side
=
self
.
side
,
sorter
=
sorter
)
.
astype
(
node
.
outputs
[
0
]
.
dtype
node
.
outputs
[
0
]
.
dtype
)
)
...
@@ -310,7 +310,7 @@ class CumOp(COp):
...
@@ -310,7 +310,7 @@ class CumOp(COp):
return
Apply
(
self
,
[
x
],
[
out_type
])
return
Apply
(
self
,
[
x
],
[
out_type
])
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
x
=
inputs
[
0
]
x
=
inputs
[
0
]
z
=
output_storage
[
0
]
z
=
output_storage
[
0
]
if
self
.
mode
==
"add"
:
if
self
.
mode
==
"add"
:
...
...
pytensor/tensor/math.py
浏览文件 @
35f0df96
...
@@ -152,9 +152,9 @@ class MaxAndArgmax(COp):
...
@@ -152,9 +152,9 @@ class MaxAndArgmax(COp):
]
]
return
Apply
(
self
,
inputs
,
outputs
)
return
Apply
(
self
,
inputs
,
outputs
)
def
perform
(
self
,
node
,
inp
,
outs
,
params
):
def
perform
(
self
,
node
,
inp
,
outs
):
x
=
inp
[
0
]
x
=
inp
[
0
]
axes
=
param
s
axes
=
self
.
axi
s
max
,
max_idx
=
outs
max
,
max_idx
=
outs
if
axes
is
None
:
if
axes
is
None
:
axes
=
tuple
(
range
(
x
.
ndim
))
axes
=
tuple
(
range
(
x
.
ndim
))
...
@@ -374,7 +374,7 @@ class Argmax(COp):
...
@@ -374,7 +374,7 @@ class Argmax(COp):
"You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format."
"You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format."
)
)
def
perform
(
self
,
node
,
inp
,
outs
,
params
):
def
perform
(
self
,
node
,
inp
,
outs
):
(
x
,)
=
inp
(
x
,)
=
inp
axes
=
self
.
axis
axes
=
self
.
axis
(
max_idx
,)
=
outs
(
max_idx
,)
=
outs
...
...
pytensor/tensor/rewriting/uncanonicalize.py
浏览文件 @
35f0df96
...
@@ -48,7 +48,7 @@ def local_max_and_argmax(fgraph, node):
...
@@ -48,7 +48,7 @@ def local_max_and_argmax(fgraph, node):
If we don't use the argmax, change it to a max only.
If we don't use the argmax, change it to a max only.
"""
"""
if
isinstance
(
node
.
op
,
MaxAndArgmax
):
if
isinstance
(
node
.
op
,
MaxAndArgmax
):
axis
=
node
.
op
.
get_params
(
node
)
axis
=
node
.
op
.
axis
if
len
(
fgraph
.
clients
[
node
.
outputs
[
1
]])
==
0
:
if
len
(
fgraph
.
clients
[
node
.
outputs
[
1
]])
==
0
:
new
=
Max
(
axis
)(
node
.
inputs
[
0
])
new
=
Max
(
axis
)(
node
.
inputs
[
0
])
copy_stack_trace
(
node
.
outputs
[
0
],
new
)
copy_stack_trace
(
node
.
outputs
[
0
],
new
)
...
...
pytensor/tensor/shape.py
浏览文件 @
35f0df96
...
@@ -237,7 +237,7 @@ class Shape_i(COp):
...
@@ -237,7 +237,7 @@ class Shape_i(COp):
raise
TypeError
(
f
"{x} has too few dimensions for Shape_i"
)
raise
TypeError
(
f
"{x} has too few dimensions for Shape_i"
)
return
Apply
(
self
,
[
x
],
[
pytensor
.
tensor
.
type
.
lscalar
()])
return
Apply
(
self
,
[
x
],
[
pytensor
.
tensor
.
type
.
lscalar
()])
def
perform
(
self
,
node
,
inp
,
out_
,
params
):
def
perform
(
self
,
node
,
inp
,
out_
):
(
x
,)
=
inp
(
x
,)
=
inp
(
out
,)
=
out_
(
out
,)
=
out_
if
out
[
0
]
is
None
:
if
out
[
0
]
is
None
:
...
@@ -668,7 +668,7 @@ class Reshape(COp):
...
@@ -668,7 +668,7 @@ class Reshape(COp):
return
Apply
(
self
,
[
x
,
shp
],
[
tensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
)])
return
Apply
(
self
,
[
x
,
shp
],
[
tensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
)])
def
perform
(
self
,
node
,
inp
,
out_
,
params
=
None
):
def
perform
(
self
,
node
,
inp
,
out_
):
x
,
shp
=
inp
x
,
shp
=
inp
(
out
,)
=
out_
(
out
,)
=
out_
if
len
(
shp
)
!=
self
.
ndim
:
if
len
(
shp
)
!=
self
.
ndim
:
...
...
pytensor/tensor/subtensor.py
浏览文件 @
35f0df96
...
@@ -2474,7 +2474,7 @@ class AdvancedIncSubtensor1(COp):
...
@@ -2474,7 +2474,7 @@ class AdvancedIncSubtensor1(COp):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
8
,)
return
(
8
,)
def
perform
(
self
,
node
,
inp
,
out_
,
params
):
def
perform
(
self
,
node
,
inp
,
out_
):
x
,
y
,
idx
=
inp
x
,
y
,
idx
=
inp
(
out
,)
=
out_
(
out
,)
=
out_
if
not
self
.
inplace
:
if
not
self
.
inplace
:
...
...
tests/link/c/test_params_type.py
浏览文件 @
35f0df96
...
@@ -31,7 +31,8 @@ class QuadraticOpFunc(COp):
...
@@ -31,7 +31,8 @@ class QuadraticOpFunc(COp):
x
=
at
.
as_tensor_variable
(
x
)
x
=
at
.
as_tensor_variable
(
x
)
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
def
perform
(
self
,
node
,
inputs
,
output_storage
,
coefficients
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
coefficients
=
self
.
params_type
.
filter
(
self
.
get_params
(
node
))
x
=
inputs
[
0
]
x
=
inputs
[
0
]
y
=
output_storage
[
0
]
y
=
output_storage
[
0
]
y
[
0
]
=
coefficients
.
a
*
(
x
**
2
)
+
coefficients
.
b
*
x
+
coefficients
.
c
y
[
0
]
=
coefficients
.
a
*
(
x
**
2
)
+
coefficients
.
b
*
x
+
coefficients
.
c
...
@@ -117,7 +118,8 @@ class QuadraticCOpFunc(ExternalCOp):
...
@@ -117,7 +118,8 @@ class QuadraticCOpFunc(ExternalCOp):
x
=
at
.
as_tensor_variable
(
x
)
x
=
at
.
as_tensor_variable
(
x
)
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
def
perform
(
self
,
node
,
inputs
,
output_storage
,
coefficients
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
coefficients
=
self
.
params_type
.
filter
(
self
.
get_params
(
node
))
x
=
inputs
[
0
]
x
=
inputs
[
0
]
y
=
output_storage
[
0
]
y
=
output_storage
[
0
]
y
[
0
]
=
coefficients
.
a
*
(
x
**
2
)
+
coefficients
.
b
*
x
+
coefficients
.
c
y
[
0
]
=
coefficients
.
a
*
(
x
**
2
)
+
coefficients
.
b
*
x
+
coefficients
.
c
...
...
tests/link/c/test_type.py
浏览文件 @
35f0df96
...
@@ -117,7 +117,8 @@ class MyOpEnumList(COp):
...
@@ -117,7 +117,8 @@ class MyOpEnumList(COp):
def
make_node
(
self
,
a
,
b
):
def
make_node
(
self
,
a
,
b
):
return
Apply
(
self
,
[
aes
.
as_scalar
(
a
),
aes
.
as_scalar
(
b
)],
[
aes
.
float64
()])
return
Apply
(
self
,
[
aes
.
as_scalar
(
a
),
aes
.
as_scalar
(
b
)],
[
aes
.
float64
()])
def
perform
(
self
,
node
,
inputs
,
outputs
,
op
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
op
=
self
.
params_type
.
filter
(
self
.
get_params
(
node
))
a
,
b
=
inputs
a
,
b
=
inputs
(
o
,)
=
outputs
(
o
,)
=
outputs
if
op
==
self
.
params_type
.
ADD
:
if
op
==
self
.
params_type
.
ADD
:
...
...
tests/tensor/test_blockwise.py
浏览文件 @
35f0df96
...
@@ -12,6 +12,7 @@ from pytensor.graph.replace import vectorize_node
...
@@ -12,6 +12,7 @@ from pytensor.graph.replace import vectorize_node
from
pytensor.tensor
import
diagonal
,
log
,
tensor
from
pytensor.tensor
import
diagonal
,
log
,
tensor
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.nlinalg
import
MatrixInverse
from
pytensor.tensor.nlinalg
import
MatrixInverse
from
pytensor.tensor.shape
import
Shape
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
,
cholesky
,
solve_triangular
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
,
cholesky
,
solve_triangular
from
pytensor.tensor.utils
import
_parse_gufunc_signature
from
pytensor.tensor.utils
import
_parse_gufunc_signature
...
@@ -359,3 +360,13 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
...
@@ -359,3 +360,13 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
fn
=
pytensor
.
function
([
value
,
mu
,
cov
],
[
logp
,
*
dlogp
])
fn
=
pytensor
.
function
([
value
,
mu
,
cov
],
[
logp
,
*
dlogp
])
benchmark
(
fn
,
*
test_values
)
benchmark
(
fn
,
*
test_values
)
def
test_op_with_params
():
matrix_shape_blockwise
=
Blockwise
(
core_op
=
Shape
(),
signature
=
"(x1,x2)->(s)"
)
x
=
tensor
(
"x"
,
shape
=
(
5
,
None
,
None
),
dtype
=
"float64"
)
x_shape
=
matrix_shape_blockwise
(
x
)
fn
=
pytensor
.
function
([
x
],
x_shape
)
pytensor
.
dprint
(
fn
)
# Assert blockwise
print
(
fn
(
np
.
zeros
((
5
,
3
,
2
))))
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论