Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
01d20497
提交
01d20497
authored
5月 07, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
6月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Parameterize Variable type by Type and Apply
上级
de9ad202
显示空白字符变更
内嵌
并排
正在显示
13 个修改的文件
包含
79 行增加
和
59 行删除
+79
-59
basic.py
aesara/graph/basic.py
+34
-28
type.py
aesara/graph/type.py
+1
-1
basic.py
aesara/link/basic.py
+1
-1
basic.py
aesara/scalar/basic.py
+1
-1
op.py
aesara/scan/op.py
+1
-1
opt.py
aesara/scan/opt.py
+7
-4
io.py
aesara/tensor/io.py
+4
-4
var.py
aesara/tensor/var.py
+18
-7
how_to_make_ops.rst
doc/sandbox/how_to_make_ops.rst
+2
-2
test_basic.py
tests/graph/test_basic.py
+1
-1
test_io.py
tests/tensor/test_io.py
+4
-4
test_shape.py
tests/tensor/test_shape.py
+2
-2
test_var.py
tests/tensor/test_var.py
+3
-3
没有找到文件。
aesara/graph/basic.py
浏览文件 @
01d20497
...
@@ -6,6 +6,7 @@ from copy import copy
...
@@ -6,6 +6,7 @@ from copy import copy
from
itertools
import
count
from
itertools
import
count
from
typing
import
(
from
typing
import
(
TYPE_CHECKING
,
TYPE_CHECKING
,
Any
,
Callable
,
Callable
,
Collection
,
Collection
,
Deque
,
Deque
,
...
@@ -47,6 +48,9 @@ if TYPE_CHECKING:
...
@@ -47,6 +48,9 @@ if TYPE_CHECKING:
OpType
=
TypeVar
(
"OpType"
,
bound
=
"Op"
)
OpType
=
TypeVar
(
"OpType"
,
bound
=
"Op"
)
OptionalApplyType
=
TypeVar
(
"OptionalApplyType"
,
None
,
"Apply"
,
covariant
=
True
)
_TypeType
=
TypeVar
(
"_TypeType"
,
bound
=
"Type"
)
_IdType
=
TypeVar
(
"_IdType"
,
bound
=
Hashable
)
T
=
TypeVar
(
"T"
,
bound
=
"Node"
)
T
=
TypeVar
(
"T"
,
bound
=
"Node"
)
NoParams
=
object
()
NoParams
=
object
()
...
@@ -61,7 +65,6 @@ class Node(MetaObject):
...
@@ -61,7 +65,6 @@ class Node(MetaObject):
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
"""
"""
type
:
"Type"
name
:
Optional
[
str
]
name
:
Optional
[
str
]
def
get_parents
(
self
):
def
get_parents
(
self
):
...
@@ -110,7 +113,10 @@ class Apply(Node, Generic[OpType]):
...
@@ -110,7 +113,10 @@ class Apply(Node, Generic[OpType]):
"""
"""
def
__init__
(
def
__init__
(
self
,
op
:
OpType
,
inputs
:
Sequence
[
"Variable"
],
outputs
:
Sequence
[
"Variable"
]
self
,
op
:
OpType
,
inputs
:
Sequence
[
"Variable"
],
outputs
:
Sequence
[
"Variable"
],
):
):
if
not
isinstance
(
inputs
,
Sequence
):
if
not
isinstance
(
inputs
,
Sequence
):
raise
TypeError
(
"The inputs of an Apply must be a sequence type"
)
raise
TypeError
(
"The inputs of an Apply must be a sequence type"
)
...
@@ -309,7 +315,7 @@ class Apply(Node, Generic[OpType]):
...
@@ -309,7 +315,7 @@ class Apply(Node, Generic[OpType]):
return
self
.
op
.
params_type
return
self
.
op
.
params_type
class
Variable
(
Node
):
class
Variable
(
Node
,
Generic
[
_TypeType
,
OptionalApplyType
]
):
r"""
r"""
A :term:`Variable` is a node in an expression graph that represents a
A :term:`Variable` is a node in an expression graph that represents a
variable.
variable.
...
@@ -407,10 +413,10 @@ class Variable(Node):
...
@@ -407,10 +413,10 @@ class Variable(Node):
# __slots__ = ['type', 'owner', 'index', 'name']
# __slots__ = ['type', 'owner', 'index', 'name']
__count__
=
count
(
0
)
__count__
=
count
(
0
)
_owner
:
Optional
[
Apply
]
_owner
:
Optional
ApplyType
@property
@property
def
owner
(
self
)
->
Optional
[
Apply
]
:
def
owner
(
self
)
->
Optional
ApplyType
:
return
self
.
_owner
return
self
.
_owner
@owner.setter
@owner.setter
...
@@ -427,30 +433,31 @@ class Variable(Node):
...
@@ -427,30 +433,31 @@ class Variable(Node):
def
__init__
(
def
__init__
(
self
,
self
,
type
,
type
:
_TypeType
,
owner
:
Optional
[
Apply
]
=
Non
e
,
owner
:
Optional
ApplyTyp
e
,
index
:
Optional
[
int
]
=
None
,
index
:
Optional
[
int
]
=
None
,
name
:
Optional
[
str
]
=
None
,
name
:
Optional
[
str
]
=
None
,
):
)
->
None
:
super
()
.
__init__
()
super
()
.
__init__
()
self
.
tag
=
ValidatingScratchpad
(
"test_value"
,
type
.
filter
)
self
.
tag
=
ValidatingScratchpad
(
"test_value"
,
type
.
filter
)
self
.
type
=
type
self
.
type
=
type
self
.
_owner
=
owner
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
raise
TypeError
(
"owner must be an Apply instance"
)
self
.
owner
=
owner
if
index
is
not
None
and
not
isinstance
(
index
,
int
):
if
index
is
not
None
and
not
isinstance
(
index
,
int
):
raise
TypeError
(
"index must be an int"
,
index
)
raise
TypeError
(
"index must be an int"
)
self
.
index
=
index
self
.
index
=
index
if
name
is
not
None
and
not
isinstance
(
name
,
str
):
if
name
is
not
None
and
not
isinstance
(
name
,
str
):
raise
TypeError
(
"name must be a string"
,
name
)
raise
TypeError
(
"name must be a string"
)
self
.
name
=
name
self
.
name
=
name
self
.
auto_name
=
"auto_"
+
str
(
next
(
self
.
__count__
))
self
.
auto_name
=
f
"auto_{next(self.__count__)}"
def
get_test_value
(
self
):
def
get_test_value
(
self
):
"""Get the test value.
"""Get the test value.
...
@@ -516,7 +523,6 @@ class Variable(Node):
...
@@ -516,7 +523,6 @@ class Variable(Node):
Tags and names are copied to the returned instance.
Tags and names are copied to the returned instance.
"""
"""
# return copy(self)
cp
=
self
.
__class__
(
self
.
type
,
None
,
None
,
self
.
name
)
cp
=
self
.
__class__
(
self
.
type
,
None
,
None
,
self
.
name
)
cp
.
tag
=
copy
(
self
.
tag
)
cp
.
tag
=
copy
(
self
.
tag
)
return
cp
return
cp
...
@@ -612,11 +618,11 @@ class Variable(Node):
...
@@ -612,11 +618,11 @@ class Variable(Node):
return
d
return
d
class
AtomicVariable
(
Variable
):
class
AtomicVariable
(
Variable
[
_TypeType
,
None
]
):
"""A node type that has no ancestors and should never be considered an input to a graph."""
"""A node type that has no ancestors and should never be considered an input to a graph."""
def
__init__
(
self
,
type
,
**
kwargs
):
def
__init__
(
self
,
type
:
_TypeType
,
**
kwargs
):
super
()
.
__init__
(
type
,
**
kwargs
)
super
()
.
__init__
(
type
,
None
,
None
,
**
kwargs
)
@abc.abstractmethod
@abc.abstractmethod
def
signature
(
self
):
def
signature
(
self
):
...
@@ -651,13 +657,13 @@ class AtomicVariable(Variable):
...
@@ -651,13 +657,13 @@ class AtomicVariable(Variable):
raise
ValueError
(
"AtomicVariable instances cannot have an index."
)
raise
ValueError
(
"AtomicVariable instances cannot have an index."
)
class
NominalVariable
(
AtomicVariable
):
class
NominalVariable
(
AtomicVariable
[
_TypeType
]
):
"""A variable that enables alpha-equivalent comparisons."""
"""A variable that enables alpha-equivalent comparisons."""
__instances__
:
Dict
[
Hashable
,
type
]
=
{}
__instances__
:
Dict
[
Tuple
[
"Type"
,
Hashable
],
"NominalVariable"
]
=
{}
def
__new__
(
cls
,
id
,
typ
,
**
kwargs
):
def
__new__
(
cls
,
id
:
_IdType
,
typ
:
_TypeType
,
**
kwargs
):
if
(
id
,
typ
)
not
in
cls
.
__instances__
:
if
(
typ
,
id
)
not
in
cls
.
__instances__
:
var_type
=
typ
.
variable_type
var_type
=
typ
.
variable_type
type_name
=
f
"Nominal{var_type.__name__}"
type_name
=
f
"Nominal{var_type.__name__}"
...
@@ -670,13 +676,13 @@ class NominalVariable(AtomicVariable):
...
@@ -670,13 +676,13 @@ class NominalVariable(AtomicVariable):
new_type
=
type
(
new_type
=
type
(
type_name
,
(
cls
,
var_type
),
{
"__reduce__"
:
_reduce
,
"__str__"
:
_str
}
type_name
,
(
cls
,
var_type
),
{
"__reduce__"
:
_reduce
,
"__str__"
:
_str
}
)
)
res
=
super
()
.
__new__
(
new_type
)
res
:
NominalVariable
=
super
()
.
__new__
(
new_type
)
cls
.
__instances__
[(
id
,
typ
)]
=
res
cls
.
__instances__
[(
typ
,
id
)]
=
res
return
cls
.
__instances__
[(
id
,
typ
)]
return
cls
.
__instances__
[(
typ
,
id
)]
def
__init__
(
self
,
id
,
typ
,
**
kwargs
):
def
__init__
(
self
,
id
:
_IdType
,
typ
:
_TypeType
,
**
kwargs
):
self
.
id
=
id
self
.
id
=
id
super
()
.
__init__
(
typ
,
**
kwargs
)
super
()
.
__init__
(
typ
,
**
kwargs
)
...
@@ -699,11 +705,11 @@ class NominalVariable(AtomicVariable):
...
@@ -699,11 +705,11 @@ class NominalVariable(AtomicVariable):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
return
f
"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
def
signature
(
self
):
def
signature
(
self
)
->
Tuple
[
_TypeType
,
_IdType
]
:
return
(
self
.
type
,
self
.
id
)
return
(
self
.
type
,
self
.
id
)
class
Constant
(
AtomicVariable
):
class
Constant
(
AtomicVariable
[
_TypeType
]
):
"""A `Variable` with a fixed `data` field.
"""A `Variable` with a fixed `data` field.
`Constant` nodes make numerous optimizations possible (e.g. constant
`Constant` nodes make numerous optimizations possible (e.g. constant
...
@@ -718,7 +724,7 @@ class Constant(AtomicVariable):
...
@@ -718,7 +724,7 @@ class Constant(AtomicVariable):
# __slots__ = ['data']
# __slots__ = ['data']
def
__init__
(
self
,
type
,
data
,
name
=
None
):
def
__init__
(
self
,
type
:
_TypeType
,
data
:
Any
,
name
:
Optional
[
str
]
=
None
):
super
()
.
__init__
(
type
,
name
=
name
)
super
()
.
__init__
(
type
,
name
=
name
)
self
.
data
=
type
.
filter
(
data
)
self
.
data
=
type
.
filter
(
data
)
add_tag_trace
(
self
)
add_tag_trace
(
self
)
...
...
aesara/graph/type.py
浏览文件 @
01d20497
...
@@ -197,7 +197,7 @@ class Type(MetaObject):
...
@@ -197,7 +197,7 @@ class Type(MetaObject):
A pretty string for printing and debugging.
A pretty string for printing and debugging.
"""
"""
return
self
.
variable_type
(
self
,
name
=
name
)
return
self
.
variable_type
(
self
,
None
,
name
=
name
)
def
make_constant
(
self
,
value
:
D
,
name
:
Optional
[
Text
]
=
None
)
->
constant_type
:
def
make_constant
(
self
,
value
:
D
,
name
:
Optional
[
Text
]
=
None
)
->
constant_type
:
"""Return a new `Constant` instance of this `Type`.
"""Return a new `Constant` instance of this `Type`.
...
...
aesara/link/basic.py
浏览文件 @
01d20497
...
@@ -207,7 +207,7 @@ class Linker(ABC):
...
@@ -207,7 +207,7 @@ class Linker(ABC):
Examples
Examples
--------
--------
x, y = Variable(Double
), Variable(Doubl
e)
x, y = Variable(Double
, None), Variable(Double, Non
e)
e = x + y
e = x + y
fgraph = FunctionGraph([x, y], [e])
fgraph = FunctionGraph([x, y], [e])
fn, (new_x, new_y), (new_e, ) = MyLinker(fgraph).make_thunk(inplace)
fn, (new_x, new_y), (new_e, ) = MyLinker(fgraph).make_thunk(inplace)
...
...
aesara/scalar/basic.py
浏览文件 @
01d20497
...
@@ -415,7 +415,7 @@ class ScalarType(CType, HasDataType, HasShape):
...
@@ -415,7 +415,7 @@ class ScalarType(CType, HasDataType, HasShape):
return
upcast
(
*
[
x
.
dtype
for
x
in
[
self
]
+
list
(
others
)])
return
upcast
(
*
[
x
.
dtype
for
x
in
[
self
]
+
list
(
others
)])
def
make_variable
(
self
,
name
=
None
):
def
make_variable
(
self
,
name
=
None
):
return
ScalarVariable
(
self
,
name
=
name
)
return
ScalarVariable
(
self
,
None
,
name
=
name
)
def
__str__
(
self
):
def
__str__
(
self
):
return
str
(
self
.
dtype
)
return
str
(
self
.
dtype
)
...
...
aesara/scan/op.py
浏览文件 @
01d20497
...
@@ -1483,7 +1483,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1483,7 +1483,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def
inner_outputs
(
self
):
def
inner_outputs
(
self
):
return
self
.
fgraph
.
outputs
return
self
.
fgraph
.
outputs
def
clone
(
self
):
def
clone
(
self
)
->
"Scan"
:
res
=
copy
(
self
)
res
=
copy
(
self
)
res
.
fgraph
=
res
.
fgraph
.
clone
()
res
.
fgraph
=
res
.
fgraph
.
clone
()
return
res
return
res
...
...
aesara/scan/opt.py
浏览文件 @
01d20497
...
@@ -939,7 +939,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -939,7 +939,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
fgraph
.
attach_feature
(
DestroyHandler
())
fgraph
.
attach_feature
(
DestroyHandler
())
def
attempt_scan_inplace
(
def
attempt_scan_inplace
(
self
,
fgraph
:
FunctionGraph
,
node
:
Apply
,
output_indices
:
List
[
int
]
self
,
fgraph
:
FunctionGraph
,
node
:
Apply
[
Scan
]
,
output_indices
:
List
[
int
]
)
->
Optional
[
Apply
]:
)
->
Optional
[
Apply
]:
"""Attempt to replace a `Scan` node by one which computes the specified outputs inplace.
"""Attempt to replace a `Scan` node by one which computes the specified outputs inplace.
...
@@ -953,7 +953,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -953,7 +953,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
Indices of the outputs to attempt to compute inplace
Indices of the outputs to attempt to compute inplace
"""
"""
op
:
Scan
=
cast
(
Scan
,
node
.
op
)
op
=
node
.
op
# inputs corresponding to sequences and n_steps
# inputs corresponding to sequences and n_steps
ls_begin
=
node
.
inputs
[:
1
+
op
.
info
.
n_seqs
]
ls_begin
=
node
.
inputs
[:
1
+
op
.
info
.
n_seqs
]
...
@@ -1001,7 +1001,10 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1001,7 +1001,10 @@ class ScanInplaceOptimizer(GlobalOptimizer):
new_op
.
destroy_map
=
destroy_map
new_op
.
destroy_map
=
destroy_map
# Do not call make_node for test_value
# Do not call make_node for test_value
new_outs
:
List
[
Variable
]
=
new_op
(
*
inputs
,
return_list
=
True
)
new_outs
=
new_op
(
*
inputs
,
return_list
=
True
)
assert
isinstance
(
new_outs
,
list
)
try
:
try
:
# TODO FIXME: We need to stop using this approach (i.e. attempt
# TODO FIXME: We need to stop using this approach (i.e. attempt
# in-place replacements and wait for downstream failures to revert
# in-place replacements and wait for downstream failures to revert
...
@@ -1015,7 +1018,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1015,7 +1018,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
remove
=
[
node
],
remove
=
[
node
],
reason
=
"scan_make_inplace"
,
reason
=
"scan_make_inplace"
,
)
)
return
new_outs
[
0
]
.
owner
return
cast
(
Apply
[
Scan
],
new_outs
[
0
]
.
owner
)
except
InconsistencyError
:
except
InconsistencyError
:
# Failed moving output to be computed inplace
# Failed moving output to be computed inplace
return
None
return
None
...
...
aesara/tensor/io.py
浏览文件 @
01d20497
...
@@ -82,7 +82,7 @@ def load(path, dtype, broadcastable, mmap_mode=None):
...
@@ -82,7 +82,7 @@ def load(path, dtype, broadcastable, mmap_mode=None):
Examples
Examples
--------
--------
>>> from aesara import *
>>> from aesara import *
>>> path = Variable(Generic())
>>> path = Variable(Generic()
, None
)
>>> x = tensor.load(path, 'int64', (False,))
>>> x = tensor.load(path, 'int64', (False,))
>>> y = x*2
>>> y = x*2
>>> fn = function([path], y)
>>> fn = function([path], y)
...
@@ -136,7 +136,7 @@ class MPIRecv(Op):
...
@@ -136,7 +136,7 @@ class MPIRecv(Op):
self
,
self
,
[],
[],
[
[
Variable
(
Generic
()),
Variable
(
Generic
()
,
None
),
tensor
(
self
.
dtype
,
shape
=
self
.
broadcastable
),
tensor
(
self
.
dtype
,
shape
=
self
.
broadcastable
),
],
],
)
)
...
@@ -222,7 +222,7 @@ class MPISend(Op):
...
@@ -222,7 +222,7 @@ class MPISend(Op):
self
.
tag
=
tag
self
.
tag
=
tag
def
make_node
(
self
,
data
):
def
make_node
(
self
,
data
):
return
Apply
(
self
,
[
data
],
[
Variable
(
Generic
()),
data
.
type
()])
return
Apply
(
self
,
[
data
],
[
Variable
(
Generic
()
,
None
),
data
.
type
()])
view_map
=
{
1
:
[
0
]}
view_map
=
{
1
:
[
0
]}
...
@@ -259,7 +259,7 @@ class MPISendWait(Op):
...
@@ -259,7 +259,7 @@ class MPISendWait(Op):
self
.
tag
=
tag
self
.
tag
=
tag
def
make_node
(
self
,
request
,
data
):
def
make_node
(
self
,
request
,
data
):
return
Apply
(
self
,
[
request
,
data
],
[
Variable
(
Generic
())])
return
Apply
(
self
,
[
request
,
data
],
[
Variable
(
Generic
()
,
None
)])
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inp
,
out
):
request
=
inp
[
0
]
request
=
inp
[
0
]
...
...
aesara/tensor/var.py
浏览文件 @
01d20497
...
@@ -3,13 +3,13 @@ import traceback as tb
...
@@ -3,13 +3,13 @@ import traceback as tb
import
warnings
import
warnings
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
numbers
import
Number
from
numbers
import
Number
from
typing
import
Optional
from
typing
import
Optional
,
TypeVar
import
numpy
as
np
import
numpy
as
np
from
aesara
import
tensor
as
at
from
aesara
import
tensor
as
at
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.basic
import
Constant
,
OptionalApplyType
,
Variable
from
aesara.graph.utils
import
MetaType
from
aesara.graph.utils
import
MetaType
from
aesara.scalar
import
ComplexError
,
IntegerDivisionError
from
aesara.scalar
import
ComplexError
,
IntegerDivisionError
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
...
@@ -18,6 +18,9 @@ from aesara.tensor.type import TensorType
...
@@ -18,6 +18,9 @@ from aesara.tensor.type import TensorType
from
aesara.tensor.utils
import
hash_from_ndarray
from
aesara.tensor.utils
import
hash_from_ndarray
_TensorTypeType
=
TypeVar
(
"_TensorTypeType"
,
bound
=
TensorType
)
class
_tensor_py_operators
:
class
_tensor_py_operators
:
def
__abs__
(
self
):
def
__abs__
(
self
):
return
at
.
math
.
abs
(
self
)
return
at
.
math
.
abs
(
self
)
...
@@ -811,14 +814,22 @@ class _tensor_py_operators:
...
@@ -811,14 +814,22 @@ class _tensor_py_operators:
return
at
.
extra_ops
.
compress
(
self
,
a
,
axis
=
axis
)
return
at
.
extra_ops
.
compress
(
self
,
a
,
axis
=
axis
)
class
TensorVariable
(
_tensor_py_operators
,
Variable
):
class
TensorVariable
(
_tensor_py_operators
,
Variable
[
_TensorTypeType
,
OptionalApplyType
]
):
"""
"""
Subclass to add the tensor operators to the basic `Variable` class.
Subclass to add the tensor operators to the basic `Variable` class.
"""
"""
def
__init__
(
self
,
type
,
owner
=
None
,
index
=
None
,
name
=
None
):
def
__init__
(
super
()
.
__init__
(
type
,
owner
=
owner
,
index
=
index
,
name
=
name
)
self
,
type
:
_TensorTypeType
,
owner
:
OptionalApplyType
,
index
=
None
,
name
=
None
,
):
super
()
.
__init__
(
type
,
owner
,
index
=
index
,
name
=
name
)
if
config
.
warn_float64
!=
"ignore"
and
type
.
dtype
==
"float64"
:
if
config
.
warn_float64
!=
"ignore"
and
type
.
dtype
==
"float64"
:
msg
=
(
msg
=
(
"You are creating a TensorVariable "
"You are creating a TensorVariable "
...
@@ -979,10 +990,10 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]:
...
@@ -979,10 +990,10 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]:
return
None
return
None
class
TensorConstant
(
TensorVariable
,
Constant
):
class
TensorConstant
(
TensorVariable
,
Constant
[
_TensorTypeType
]
):
"""Subclass to add the tensor operators to the basic `Constant` class."""
"""Subclass to add the tensor operators to the basic `Constant` class."""
def
__init__
(
self
,
type
,
data
,
name
=
None
):
def
__init__
(
self
,
type
:
_TensorTypeType
,
data
,
name
=
None
):
data_shape
=
np
.
shape
(
data
)
data_shape
=
np
.
shape
(
data
)
if
len
(
data_shape
)
!=
type
.
ndim
or
any
(
if
len
(
data_shape
)
!=
type
.
ndim
or
any
(
...
...
doc/sandbox/how_to_make_ops.rst
浏览文件 @
01d20497
...
@@ -65,8 +65,8 @@ Example:
...
@@ -65,8 +65,8 @@ Example:
#...
#...
def make_node(self, x, y):
def make_node(self, x, y):
# note 1: constant, int64 and ScalarType are defined in aesara.scalar
# note 1: constant, int64 and ScalarType are defined in aesara.scalar
# note 2: constant(x) is equivalent to Constant(type
= int64, data =
x)
# note 2: constant(x) is equivalent to Constant(type
=int64, data=
x)
# note 3: the call int64() is equivalent to Variable(type
= int64) or Variable(type = ScalarType(dtype = 'int64')
)
# note 3: the call int64() is equivalent to Variable(type
=int64, None) or Variable(type=ScalarType(dtype = 'int64'), None
)
if isinstance(x, int):
if isinstance(x, int):
x = constant(x)
x = constant(x)
elif not isinstance(x, Variable) or not x.type == int64:
elif not isinstance(x, Variable) or not x.type == int64:
...
...
tests/graph/test_basic.py
浏览文件 @
01d20497
...
@@ -339,7 +339,7 @@ class TestAutoName:
...
@@ -339,7 +339,7 @@ class TestAutoName:
autoname_id
=
next
(
Variable
.
__count__
)
autoname_id
=
next
(
Variable
.
__count__
)
Variable
.
__count__
=
count
(
autoname_id
)
Variable
.
__count__
=
count
(
autoname_id
)
r1
=
TensorType
(
dtype
=
"int32"
,
shape
=
())(
"myvar"
)
r1
=
TensorType
(
dtype
=
"int32"
,
shape
=
())(
"myvar"
)
r2
=
TensorVariable
(
TensorType
(
dtype
=
"int32"
,
shape
=
()))
r2
=
TensorVariable
(
TensorType
(
dtype
=
"int32"
,
shape
=
())
,
None
)
r3
=
shared
(
np
.
random
.
standard_normal
((
3
,
4
)))
r3
=
shared
(
np
.
random
.
standard_normal
((
3
,
4
)))
assert
r1
.
auto_name
==
"auto_"
+
str
(
autoname_id
)
assert
r1
.
auto_name
==
"auto_"
+
str
(
autoname_id
)
assert
r2
.
auto_name
==
"auto_"
+
str
(
autoname_id
+
1
)
assert
r2
.
auto_name
==
"auto_"
+
str
(
autoname_id
+
1
)
...
...
tests/tensor/test_io.py
浏览文件 @
01d20497
...
@@ -17,7 +17,7 @@ class TestLoadTensor:
...
@@ -17,7 +17,7 @@ class TestLoadTensor:
np
.
save
(
self
.
filename
,
self
.
data
)
np
.
save
(
self
.
filename
,
self
.
data
)
def
test_basic
(
self
):
def
test_basic
(
self
):
path
=
Variable
(
Generic
())
path
=
Variable
(
Generic
()
,
None
)
# Not specifying mmap_mode defaults to None, and the data is
# Not specifying mmap_mode defaults to None, and the data is
# copied into main memory
# copied into main memory
x
=
load
(
path
,
"int32"
,
(
False
,))
x
=
load
(
path
,
"int32"
,
(
False
,))
...
@@ -29,13 +29,13 @@ class TestLoadTensor:
...
@@ -29,13 +29,13 @@ class TestLoadTensor:
# Modes 'r+', 'r', and 'w+' cannot work with Aesara, becausei
# Modes 'r+', 'r', and 'w+' cannot work with Aesara, becausei
# the output array may be modified inplace, and that should not
# the output array may be modified inplace, and that should not
# modify the original file.
# modify the original file.
path
=
Variable
(
Generic
())
path
=
Variable
(
Generic
()
,
None
)
for
mmap_mode
in
(
"r+"
,
"r"
,
"w+"
,
"toto"
):
for
mmap_mode
in
(
"r+"
,
"r"
,
"w+"
,
"toto"
):
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
load
(
path
,
"int32"
,
(
False
,),
mmap_mode
)
load
(
path
,
"int32"
,
(
False
,),
mmap_mode
)
def
test1
(
self
):
def
test1
(
self
):
path
=
Variable
(
Generic
())
path
=
Variable
(
Generic
()
,
None
)
# 'c' means "copy-on-write", which allow the array to be overwritten
# 'c' means "copy-on-write", which allow the array to be overwritten
# by an inplace Op in the graph, without modifying the underlying
# by an inplace Op in the graph, without modifying the underlying
# file.
# file.
...
@@ -48,7 +48,7 @@ class TestLoadTensor:
...
@@ -48,7 +48,7 @@ class TestLoadTensor:
assert
(
fn
(
self
.
filename
)
==
(
self
.
data
**
2
)
.
sum
())
.
all
()
assert
(
fn
(
self
.
filename
)
==
(
self
.
data
**
2
)
.
sum
())
.
all
()
def
test_memmap
(
self
):
def
test_memmap
(
self
):
path
=
Variable
(
Generic
())
path
=
Variable
(
Generic
()
,
None
)
x
=
load
(
path
,
"int32"
,
(
False
,),
mmap_mode
=
"c"
)
x
=
load
(
path
,
"int32"
,
(
False
,),
mmap_mode
=
"c"
)
fn
=
function
([
path
],
x
)
fn
=
function
([
path
],
x
)
assert
type
(
fn
(
self
.
filename
))
==
np
.
core
.
memmap
assert
type
(
fn
(
self
.
filename
))
==
np
.
core
.
memmap
...
...
tests/tensor/test_shape.py
浏览文件 @
01d20497
...
@@ -63,7 +63,7 @@ def test_shape_basic():
...
@@ -63,7 +63,7 @@ def test_shape_basic():
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyType
)
and
other
.
thingy
==
self
.
thingy
return
isinstance
(
other
,
MyType
)
and
other
.
thingy
==
self
.
thingy
s
=
shape
(
Variable
(
MyType
()))
s
=
shape
(
Variable
(
MyType
()
,
None
))
assert
s
.
type
.
broadcastable
==
(
False
,)
assert
s
.
type
.
broadcastable
==
(
False
,)
s
=
shape
(
np
.
array
(
1
))
s
=
shape
(
np
.
array
(
1
))
...
@@ -475,7 +475,7 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -475,7 +475,7 @@ class TestSpecifyShape(utt.InferShapeTester):
def
test_infer_shape
(
self
):
def
test_infer_shape
(
self
):
rng
=
np
.
random
.
default_rng
(
3453
)
rng
=
np
.
random
.
default_rng
(
3453
)
adtens4
=
dtensor4
()
adtens4
=
dtensor4
()
aivec
=
TensorVariable
(
TensorType
(
"int64"
,
(
4
,)))
aivec
=
TensorVariable
(
TensorType
(
"int64"
,
(
4
,))
,
None
)
aivec_val
=
[
3
,
4
,
2
,
5
]
aivec_val
=
[
3
,
4
,
2
,
5
]
adtens4_val
=
rng
.
random
(
aivec_val
)
adtens4_val
=
rng
.
random
(
aivec_val
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
...
...
tests/tensor/test_var.py
浏览文件 @
01d20497
...
@@ -234,7 +234,7 @@ def test__getitem__newaxis(x, indices, new_order):
...
@@ -234,7 +234,7 @@ def test__getitem__newaxis(x, indices, new_order):
def
test_fixed_shape_variable_basic
():
def
test_fixed_shape_variable_basic
():
x
=
TensorVariable
(
TensorType
(
"int64"
,
(
4
,)))
x
=
TensorVariable
(
TensorType
(
"int64"
,
(
4
,))
,
None
)
assert
isinstance
(
x
.
shape
,
Constant
)
assert
isinstance
(
x
.
shape
,
Constant
)
assert
np
.
array_equal
(
x
.
shape
.
data
,
(
4
,))
assert
np
.
array_equal
(
x
.
shape
.
data
,
(
4
,))
...
@@ -246,11 +246,11 @@ def test_fixed_shape_variable_basic():
...
@@ -246,11 +246,11 @@ def test_fixed_shape_variable_basic():
def
test_get_vector_length
():
def
test_get_vector_length
():
x
=
TensorVariable
(
TensorType
(
"int64"
,
(
4
,)))
x
=
TensorVariable
(
TensorType
(
"int64"
,
(
4
,))
,
None
)
res
=
get_vector_length
(
x
)
res
=
get_vector_length
(
x
)
assert
res
==
4
assert
res
==
4
x
=
TensorVariable
(
TensorType
(
"int64"
,
(
None
,)))
x
=
TensorVariable
(
TensorType
(
"int64"
,
(
None
,))
,
None
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
get_vector_length
(
x
)
get_vector_length
(
x
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论