Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
35f0df96
提交
35f0df96
authored
11月 22, 2023
作者:
Ricardo Vieira
提交者:
Thomas Wiecki
11月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make params exclusive to COp's
Also removes them from the signature of perform
上级
5f840278
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
105 行增加
和
132 行删除
+105
-132
using_params.rst
doc/extending/using_params.rst
+17
-21
basic.py
pytensor/graph/basic.py
+0
-11
op.py
pytensor/graph/op.py
+9
-56
basic.py
pytensor/link/c/basic.py
+11
-2
interface.py
pytensor/link/c/interface.py
+24
-1
op.py
pytensor/link/c/op.py
+2
-2
basic.py
pytensor/link/numba/dispatch/basic.py
+5
-16
raise_op.py
pytensor/raise_op.py
+1
-1
op.py
pytensor/scan/op.py
+1
-1
basic.py
pytensor/tensor/basic.py
+3
-3
blas.py
pytensor/tensor/blas.py
+4
-4
elemwise.py
pytensor/tensor/elemwise.py
+1
-1
extra_ops.py
pytensor/tensor/extra_ops.py
+3
-3
math.py
pytensor/tensor/math.py
+3
-3
uncanonicalize.py
pytensor/tensor/rewriting/uncanonicalize.py
+1
-1
shape.py
pytensor/tensor/shape.py
+2
-2
subtensor.py
pytensor/tensor/subtensor.py
+1
-1
test_params_type.py
tests/link/c/test_params_type.py
+4
-2
test_type.py
tests/link/c/test_type.py
+2
-1
test_blockwise.py
tests/tensor/test_blockwise.py
+11
-0
没有找到文件。
doc/extending/using_params.rst
浏览文件 @
35f0df96
.. _extending_op_params:
===============
Using Op params
===============
===============
=
Using
C
Op params
===============
=
The Op params is a facility to pass some runtime parameters to the
The
C
Op params is a facility to pass some runtime parameters to the
code of an op without modifying it. It can enable a single instance
of C code to serve different needs and therefore reduce compilation.
...
...
@@ -53,7 +53,7 @@ following methods will be used for the type:
- :meth:`__hash__ <Type.__hash__>`
- :meth:`values_eq <Type.values_eq>`
Additionally
if you want
to use your params with C code, you need to extend `COp`
Additionally
,
to use your params with C code, you need to extend `COp`
and implement the following methods:
- :meth:`c_declare <CLinkerType.c_declare>`
...
...
@@ -65,24 +65,24 @@ You can also define other convenience methods such as
:meth:`c_headers <CLinkerType.c_headers>` if you need any special things.
Registering the params with your Op
-----------------------------------
Registering the params with your
C
Op
-----------------------------------
-
To declare that your
Op
uses params you have to set the class
To declare that your
`COp`
uses params you have to set the class
attribute :attr:`params_type` to an instance of your params Type.
.. note::
If you want to have multiple parameters, PyTensor provides the convenient class
:class:`pytensor.link.c.params_type.ParamsType` that allows to bundle many parameters into
one object that will be available
in both Python (as a Python object) and
C code (as a struct).
one object that will be available
to the
C code (as a struct).
For example if we decide to use an int as the params the following
would be appropriate:
.. code-block:: python
class MyOp(Op):
class MyOp(
C
Op):
params_type = Generic()
After that you need to define a :meth:`get_params` method on your
...
...
@@ -115,12 +115,7 @@ Having declared a params for your Op will affect the expected
signature of :meth:`perform`. The new expected signature will have an
extra parameter at the end which corresponds to the params object.
.. warning::
If you do not account for this extra parameter, the code will fail
at runtime if it tries to run the python version.
Also, for the C code, the `sub` dictionary will contain an extra entry
The `sub` dictionary for `COp`s with params will contain an extra entry
`'params'` which will map to the variable name of the params object.
This is true for all methods that receive a `sub` parameter, so this
means that you can use your params in the :meth:`c_code <COp.c_code>`
...
...
@@ -131,7 +126,7 @@ A simple example
----------------
This is a simple example which uses a params object to pass a value.
This `Op` will multiply a scalar input by a fixed floating point value.
This `
C
Op` will multiply a scalar input by a fixed floating point value.
Since the value in this case is a python float, we chose Generic as
the params type.
...
...
@@ -156,9 +151,10 @@ the params type.
inp = as_scalar(inp)
return Apply(self, [inp], [inp.type()])
def perform(self, node, inputs, output_storage, params):
# Here params is a python float so this is ok
output_storage[0][0] = inputs[0] * params
def perform(self, node, inputs, output_storage):
# Because params is a python float we can use `self.mul` directly.
# If it's something fancier, call `self.params_type.filter(self.get_params(node))`
output_storage[0][0] = inputs[0] * self.mul
def c_code(self, node, name, inputs, outputs, sub):
return ("%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);" %
...
...
@@ -174,7 +170,7 @@ weights.
.. testcode::
from pytensor.
graph.op import
Op
from pytensor.
link.c.op import C
Op
from pytensor.link.c.type import Generic
from pytensor.scalar import as_scalar
...
...
pytensor/graph/basic.py
浏览文件 @
35f0df96
...
...
@@ -30,7 +30,6 @@ import numpy as np
from
pytensor.configdefaults
import
config
from
pytensor.graph.utils
import
(
MetaObject
,
MethodNotDefined
,
Scratchpad
,
TestValueError
,
ValidatingScratchpad
,
...
...
@@ -151,16 +150,6 @@ class Apply(Node, Generic[OpType]):
f
"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
)
def
run_params
(
self
):
"""
Returns the params for the node, or NoParams if no params is set.
"""
try
:
return
self
.
op
.
get_params
(
self
)
except
MethodNotDefined
:
return
NoParams
def
__getstate__
(
self
):
d
=
self
.
__dict__
# ufunc don't pickle/unpickle well
...
...
pytensor/graph/op.py
浏览文件 @
35f0df96
...
...
@@ -16,15 +16,13 @@ from typing import (
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
NoParams
,
Variable
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.utils
import
(
MetaObject
,
MethodNotDefined
,
TestValueError
,
add_tag_trace
,
get_variable_trace_string
,
)
from
pytensor.link.c.params_type
import
Params
,
ParamsType
if
TYPE_CHECKING
:
...
...
@@ -37,10 +35,7 @@ StorageMapType = dict[Variable, StorageCellType]
ComputeMapType
=
dict
[
Variable
,
list
[
bool
]]
InputStorageType
=
list
[
StorageCellType
]
OutputStorageType
=
list
[
StorageCellType
]
ParamsInputType
=
Optional
[
tuple
[
Any
,
...
]]
PerformMethodType
=
Callable
[
[
Apply
,
list
[
Any
],
OutputStorageType
,
ParamsInputType
],
None
]
PerformMethodType
=
Callable
[[
Apply
,
list
[
Any
],
OutputStorageType
],
None
]
BasicThunkType
=
Callable
[[],
None
]
ThunkCallableType
=
Callable
[
[
PerformMethodType
,
StorageMapType
,
ComputeMapType
,
Apply
],
None
...
...
@@ -202,7 +197,6 @@ class Op(MetaObject):
itypes
:
Optional
[
Sequence
[
"Type"
]]
=
None
otypes
:
Optional
[
Sequence
[
"Type"
]]
=
None
params_type
:
Optional
[
ParamsType
]
=
None
_output_type_depends_on_input_value
=
False
"""
...
...
@@ -426,7 +420,6 @@ class Op(MetaObject):
node
:
Apply
,
inputs
:
Sequence
[
Any
],
output_storage
:
OutputStorageType
,
params
:
ParamsInputType
=
None
,
)
->
None
:
"""Calculate the function on the inputs and put the variables in the output storage.
...
...
@@ -442,8 +435,6 @@ class Op(MetaObject):
these lists). Each sub-list corresponds to value of each
`Variable` in :attr:`node.outputs`. The primary purpose of this method
is to set the values of these sub-lists.
params
A tuple containing the values of each entry in :attr:`Op.__props__`.
Notes
-----
...
...
@@ -481,22 +472,6 @@ class Op(MetaObject):
"""
return
True
def
get_params
(
self
,
node
:
Apply
)
->
Params
:
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if
isinstance
(
self
.
params_type
,
ParamsType
):
wrapper
=
self
.
params_type
if
not
all
(
hasattr
(
self
,
field
)
for
field
in
wrapper
.
fields
):
# Let's print missing attributes for debugging.
not_found
=
tuple
(
field
for
field
in
wrapper
.
fields
if
not
hasattr
(
self
,
field
)
)
raise
AttributeError
(
f
"{type(self).__name__}: missing attributes {not_found} for ParamsType."
)
# ParamsType.get_params() will apply filtering to attributes.
return
self
.
params_type
.
get_params
(
self
)
raise
MethodNotDefined
(
"get_params"
)
def
prepare_node
(
self
,
node
:
Apply
,
...
...
@@ -538,34 +513,12 @@ class Op(MetaObject):
else
:
p
=
node
.
op
.
perform
params
=
node
.
run_params
()
if
params
is
NoParams
:
# default arguments are stored in the closure of `rval`
@is_thunk_type
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
,
params
=
None
):
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
for
o
in
node
.
outputs
:
compute_map
[
o
][
0
]
=
True
return
r
else
:
params_val
=
node
.
params_type
.
filter
(
params
)
@is_thunk_type
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
,
params
=
params_val
,
):
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
,
params
)
for
o
in
node
.
outputs
:
compute_map
[
o
][
0
]
=
True
return
r
@is_thunk_type
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
):
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
for
o
in
node
.
outputs
:
compute_map
[
o
][
0
]
=
True
return
r
rval
.
inputs
=
node_input_storage
rval
.
outputs
=
node_output_storage
...
...
@@ -640,7 +593,7 @@ class _NoPythonOp(Op):
"""
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
raise
NotImplementedError
(
"No Python implementation is provided by this Op."
)
...
...
pytensor/link/c/basic.py
浏览文件 @
35f0df96
...
...
@@ -20,6 +20,7 @@ from pytensor.graph.basic import (
io_toposort
,
vars_between
,
)
from
pytensor.graph.utils
import
MethodNotDefined
from
pytensor.link.basic
import
Container
,
Linker
,
LocalLinker
,
PerformLinker
from
pytensor.link.c.cmodule
import
(
METH_VARARGS
,
...
...
@@ -617,7 +618,12 @@ class CLinker(Linker):
# that needs it
self
.
node_params
=
dict
()
for
node
in
self
.
node_order
:
params
=
node
.
run_params
()
if
not
isinstance
(
node
.
op
,
CLinkerOp
):
continue
try
:
params
=
node
.
op
.
get_params
(
node
)
except
MethodNotDefined
:
params
=
NoParams
if
params
is
not
NoParams
:
# try to avoid creating more than one variable for the
# same params.
...
...
@@ -803,7 +809,10 @@ class CLinker(Linker):
sub
=
dict
(
failure_var
=
failure_var
)
params
=
node
.
run_params
()
try
:
params
=
op
.
get_params
(
node
)
except
MethodNotDefined
:
params
=
NoParams
if
params
is
not
NoParams
:
params_var
=
symbol
[
self
.
node_params
[
params
]]
...
...
pytensor/link/c/interface.py
浏览文件 @
35f0df96
import
typing
import
warnings
from
abc
import
abstractmethod
from
typing
import
Callable
from
typing
import
Callable
,
Optional
from
pytensor.graph.basic
import
Apply
,
Constant
from
pytensor.graph.utils
import
MethodNotDefined
if
typing
.
TYPE_CHECKING
:
from
pytensor.link.c.params_type
import
Params
,
ParamsType
class
CLinkerObject
:
"""Standard methods for an `Op` or `Type` used with the `CLinker`."""
...
...
@@ -172,6 +177,8 @@ class CLinkerObject:
class
CLinkerOp
(
CLinkerObject
):
"""Interface definition for `Op` subclasses compiled by `CLinker`."""
params_type
:
Optional
[
"ParamsType"
]
=
None
@abstractmethod
def
c_code
(
self
,
...
...
@@ -362,6 +369,22 @@ class CLinkerOp(CLinkerObject):
"""
return
""
def
get_params
(
self
,
node
:
Apply
)
->
"Params"
:
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if
self
.
params_type
is
not
None
:
wrapper
=
self
.
params_type
if
not
all
(
hasattr
(
self
,
field
)
for
field
in
wrapper
.
fields
):
# Let's print missing attributes for debugging.
not_found
=
tuple
(
field
for
field
in
wrapper
.
fields
if
not
hasattr
(
self
,
field
)
)
raise
AttributeError
(
f
"{type(self).__name__}: missing attributes {not_found} for ParamsType."
)
# ParamsType.get_params() will apply filtering to attributes.
return
self
.
params_type
.
get_params
(
self
)
raise
MethodNotDefined
(
"get_params"
)
class
CLinkerType
(
CLinkerObject
):
r"""Interface specification for `Type`\s that can be arguments to a `CLinkerOp`.
...
...
pytensor/link/c/op.py
浏览文件 @
35f0df96
...
...
@@ -664,7 +664,7 @@ class _NoPythonCOp(COp):
"""
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
raise
NotImplementedError
(
"No Python implementation is provided by this COp."
)
...
...
@@ -675,7 +675,7 @@ class _NoPythonExternalCOp(ExternalCOp):
"""
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
raise
NotImplementedError
(
"No Python implementation is provided by this ExternalCOp."
)
pytensor/link/numba/dispatch/basic.py
浏览文件 @
35f0df96
...
...
@@ -21,7 +21,7 @@ from numba.extending import box, overload
from
pytensor
import
config
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.graph.basic
import
Apply
,
NoParams
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.type
import
Type
from
pytensor.ifelse
import
IfElse
...
...
@@ -383,22 +383,11 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
output_types
=
tuple
(
out
.
type
for
out
in
node
.
outputs
)
params
=
node
.
run_params
()
if
params
is
not
NoParams
:
params_val
=
dict
(
node
.
params_type
.
filter
(
params
))
def
py_perform
(
inputs
):
outputs
=
[[
None
]
for
i
in
range
(
n_outputs
)]
op
.
perform
(
node
,
inputs
,
outputs
,
params_val
)
return
outputs
else
:
def
py_perform
(
inputs
):
outputs
=
[[
None
]
for
i
in
range
(
n_outputs
)]
op
.
perform
(
node
,
inputs
,
outputs
)
return
outputs
def
py_perform
(
inputs
):
outputs
=
[[
None
]
for
i
in
range
(
n_outputs
)]
op
.
perform
(
node
,
inputs
,
outputs
)
return
outputs
if
n_outputs
==
1
:
...
...
pytensor/raise_op.py
浏览文件 @
35f0df96
...
...
@@ -90,7 +90,7 @@ class CheckAndRaise(COp):
[
value
.
type
()],
)
def
perform
(
self
,
node
,
inputs
,
outputs
,
params
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
(
out
,)
=
outputs
val
,
*
conds
=
inputs
out
[
0
]
=
val
...
...
pytensor/scan/op.py
浏览文件 @
35f0df96
...
...
@@ -1658,7 +1658,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rval
.
lazy
=
False
return
rval
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
"""Compute the scan operation in Python.
The `inputs` are packed like this:
...
...
pytensor/tensor/basic.py
浏览文件 @
35f0df96
...
...
@@ -3991,11 +3991,11 @@ class AllocEmpty(COp):
output
.
tag
.
nan_guard_mode_check
=
False
return
Apply
(
self
,
_shape
,
[
output
])
def
debug_perform
(
self
,
node
,
inputs
,
out_
,
params
):
self
.
perform
(
node
,
inputs
,
out_
,
params
)
def
debug_perform
(
self
,
node
,
inputs
,
out_
):
self
.
perform
(
node
,
inputs
,
out_
)
out_
[
0
][
0
]
.
fill
(
-
123456789
)
def
perform
(
self
,
node
,
inputs
,
out_
,
params
):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
sh
=
tuple
([
int
(
i
)
for
i
in
inputs
])
if
out
[
0
]
is
None
or
out
[
0
]
.
shape
!=
sh
:
...
...
pytensor/tensor/blas.py
浏览文件 @
35f0df96
...
...
@@ -207,7 +207,7 @@ class Gemv(Op):
return
Apply
(
self
,
inputs
,
[
y
.
type
()])
def
perform
(
self
,
node
,
inputs
,
out_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
out_storage
):
y
,
alpha
,
A
,
x
,
beta
=
inputs
if
(
have_fblas
...
...
@@ -309,7 +309,7 @@ class Ger(Op):
return
Apply
(
self
,
inputs
,
[
A
.
type
()])
def
perform
(
self
,
node
,
inp
,
out
,
params
=
None
):
def
perform
(
self
,
node
,
inp
,
out
):
cA
,
calpha
,
cx
,
cy
=
inp
(
cZ
,)
=
out
if
self
.
destructive
:
...
...
@@ -912,12 +912,12 @@ class Gemm(GemmRelated):
output
=
z
.
type
()
return
Apply
(
self
,
inputs
,
[
output
])
def
perform
(
self
,
node
,
inp
,
out
,
params
):
def
perform
(
self
,
node
,
inp
,
out
):
z
,
a
,
x
,
y
,
b
=
inp
(
zout
,)
=
out
assert
a
.
shape
==
()
assert
b
.
shape
==
()
if
not
params
.
inplace
:
if
not
self
.
inplace
:
z
=
z
.
copy
()
# the original z will not be changed
if
z
.
shape
==
():
z
.
itemset
(
z
*
a
+
b
*
np
.
dot
(
x
,
y
))
...
...
pytensor/tensor/elemwise.py
浏览文件 @
35f0df96
...
...
@@ -233,7 +233,7 @@ class DimShuffle(ExternalCOp):
return
f
"Transpose{{axes={self.shuffle}}}"
return
f
"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
def
perform
(
self
,
node
,
inp
,
out
,
params
=
None
):
def
perform
(
self
,
node
,
inp
,
out
):
(
res
,)
=
inp
(
storage
,)
=
out
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
35f0df96
...
...
@@ -145,7 +145,7 @@ class SearchsortedOp(COp):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
1
]]
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
x
=
inputs
[
0
]
v
=
inputs
[
1
]
if
len
(
node
.
inputs
)
==
3
:
...
...
@@ -154,7 +154,7 @@ class SearchsortedOp(COp):
sorter
=
None
z
=
output_storage
[
0
]
z
[
0
]
=
np
.
searchsorted
(
x
,
v
,
side
=
params
,
sorter
=
sorter
)
.
astype
(
z
[
0
]
=
np
.
searchsorted
(
x
,
v
,
side
=
self
.
side
,
sorter
=
sorter
)
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
...
...
@@ -310,7 +310,7 @@ class CumOp(COp):
return
Apply
(
self
,
[
x
],
[
out_type
])
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
x
=
inputs
[
0
]
z
=
output_storage
[
0
]
if
self
.
mode
==
"add"
:
...
...
pytensor/tensor/math.py
浏览文件 @
35f0df96
...
...
@@ -152,9 +152,9 @@ class MaxAndArgmax(COp):
]
return
Apply
(
self
,
inputs
,
outputs
)
def
perform
(
self
,
node
,
inp
,
outs
,
params
):
def
perform
(
self
,
node
,
inp
,
outs
):
x
=
inp
[
0
]
axes
=
param
s
axes
=
self
.
axi
s
max
,
max_idx
=
outs
if
axes
is
None
:
axes
=
tuple
(
range
(
x
.
ndim
))
...
...
@@ -374,7 +374,7 @@ class Argmax(COp):
"You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format."
)
def
perform
(
self
,
node
,
inp
,
outs
,
params
):
def
perform
(
self
,
node
,
inp
,
outs
):
(
x
,)
=
inp
axes
=
self
.
axis
(
max_idx
,)
=
outs
...
...
pytensor/tensor/rewriting/uncanonicalize.py
浏览文件 @
35f0df96
...
...
@@ -48,7 +48,7 @@ def local_max_and_argmax(fgraph, node):
If we don't use the argmax, change it to a max only.
"""
if
isinstance
(
node
.
op
,
MaxAndArgmax
):
axis
=
node
.
op
.
get_params
(
node
)
axis
=
node
.
op
.
axis
if
len
(
fgraph
.
clients
[
node
.
outputs
[
1
]])
==
0
:
new
=
Max
(
axis
)(
node
.
inputs
[
0
])
copy_stack_trace
(
node
.
outputs
[
0
],
new
)
...
...
pytensor/tensor/shape.py
浏览文件 @
35f0df96
...
...
@@ -237,7 +237,7 @@ class Shape_i(COp):
raise
TypeError
(
f
"{x} has too few dimensions for Shape_i"
)
return
Apply
(
self
,
[
x
],
[
pytensor
.
tensor
.
type
.
lscalar
()])
def
perform
(
self
,
node
,
inp
,
out_
,
params
):
def
perform
(
self
,
node
,
inp
,
out_
):
(
x
,)
=
inp
(
out
,)
=
out_
if
out
[
0
]
is
None
:
...
...
@@ -668,7 +668,7 @@ class Reshape(COp):
return
Apply
(
self
,
[
x
,
shp
],
[
tensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
)])
def
perform
(
self
,
node
,
inp
,
out_
,
params
=
None
):
def
perform
(
self
,
node
,
inp
,
out_
):
x
,
shp
=
inp
(
out
,)
=
out_
if
len
(
shp
)
!=
self
.
ndim
:
...
...
pytensor/tensor/subtensor.py
浏览文件 @
35f0df96
...
...
@@ -2474,7 +2474,7 @@ class AdvancedIncSubtensor1(COp):
def
c_code_cache_version
(
self
):
return
(
8
,)
def
perform
(
self
,
node
,
inp
,
out_
,
params
):
def
perform
(
self
,
node
,
inp
,
out_
):
x
,
y
,
idx
=
inp
(
out
,)
=
out_
if
not
self
.
inplace
:
...
...
tests/link/c/test_params_type.py
浏览文件 @
35f0df96
...
...
@@ -31,7 +31,8 @@ class QuadraticOpFunc(COp):
x
=
at
.
as_tensor_variable
(
x
)
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
def
perform
(
self
,
node
,
inputs
,
output_storage
,
coefficients
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
coefficients
=
self
.
params_type
.
filter
(
self
.
get_params
(
node
))
x
=
inputs
[
0
]
y
=
output_storage
[
0
]
y
[
0
]
=
coefficients
.
a
*
(
x
**
2
)
+
coefficients
.
b
*
x
+
coefficients
.
c
...
...
@@ -117,7 +118,8 @@ class QuadraticCOpFunc(ExternalCOp):
x
=
at
.
as_tensor_variable
(
x
)
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
def
perform
(
self
,
node
,
inputs
,
output_storage
,
coefficients
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
coefficients
=
self
.
params_type
.
filter
(
self
.
get_params
(
node
))
x
=
inputs
[
0
]
y
=
output_storage
[
0
]
y
[
0
]
=
coefficients
.
a
*
(
x
**
2
)
+
coefficients
.
b
*
x
+
coefficients
.
c
...
...
tests/link/c/test_type.py
浏览文件 @
35f0df96
...
...
@@ -117,7 +117,8 @@ class MyOpEnumList(COp):
def
make_node
(
self
,
a
,
b
):
return
Apply
(
self
,
[
aes
.
as_scalar
(
a
),
aes
.
as_scalar
(
b
)],
[
aes
.
float64
()])
def
perform
(
self
,
node
,
inputs
,
outputs
,
op
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
op
=
self
.
params_type
.
filter
(
self
.
get_params
(
node
))
a
,
b
=
inputs
(
o
,)
=
outputs
if
op
==
self
.
params_type
.
ADD
:
...
...
tests/tensor/test_blockwise.py
浏览文件 @
35f0df96
...
...
@@ -12,6 +12,7 @@ from pytensor.graph.replace import vectorize_node
from
pytensor.tensor
import
diagonal
,
log
,
tensor
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.nlinalg
import
MatrixInverse
from
pytensor.tensor.shape
import
Shape
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
,
cholesky
,
solve_triangular
from
pytensor.tensor.utils
import
_parse_gufunc_signature
...
...
@@ -359,3 +360,13 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
fn
=
pytensor
.
function
([
value
,
mu
,
cov
],
[
logp
,
*
dlogp
])
benchmark
(
fn
,
*
test_values
)
def
test_op_with_params
():
matrix_shape_blockwise
=
Blockwise
(
core_op
=
Shape
(),
signature
=
"(x1,x2)->(s)"
)
x
=
tensor
(
"x"
,
shape
=
(
5
,
None
,
None
),
dtype
=
"float64"
)
x_shape
=
matrix_shape_blockwise
(
x
)
fn
=
pytensor
.
function
([
x
],
x_shape
)
pytensor
.
dprint
(
fn
)
# Assert blockwise
print
(
fn
(
np
.
zeros
((
5
,
3
,
2
))))
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论