Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
fdf2a23a
提交
fdf2a23a
authored
2月 14, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
3月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make sparse tensor types extend the existing tensor types
上级
31ab8fdd
隐藏空白字符变更
内嵌
并排
正在显示
14 个修改的文件
包含
294 行增加
和
72 行删除
+294
-72
basic.py
aesara/sparse/basic.py
+18
-8
opt.py
aesara/sparse/opt.py
+14
-2
type.py
aesara/sparse/type.py
+51
-38
basic.py
aesara/tensor/basic.py
+7
-2
basic_opt.py
aesara/tensor/basic_opt.py
+8
-2
blas.py
aesara/tensor/blas.py
+37
-3
math.py
aesara/tensor/math.py
+6
-0
shape.py
aesara/tensor/shape.py
+1
-1
type.py
aesara/tensor/type.py
+19
-3
var.py
aesara/tensor/var.py
+31
-0
test_basic.py
tests/sparse/test_basic.py
+58
-8
test_type.py
tests/sparse/test_type.py
+22
-0
test_var.py
tests/tensor/test_var.py
+18
-1
test_ifelse.py
tests/test_ifelse.py
+4
-4
没有找到文件。
aesara/sparse/basic.py
浏览文件 @
fdf2a23a
...
@@ -49,6 +49,7 @@ from aesara.tensor.type import TensorType
...
@@ -49,6 +49,7 @@ from aesara.tensor.type import TensorType
from
aesara.tensor.type
import
continuous_dtypes
as
tensor_continuous_dtypes
from
aesara.tensor.type
import
continuous_dtypes
as
tensor_continuous_dtypes
from
aesara.tensor.type
import
discrete_dtypes
as
tensor_discrete_dtypes
from
aesara.tensor.type
import
discrete_dtypes
as
tensor_discrete_dtypes
from
aesara.tensor.type
import
iscalar
,
ivector
,
scalar
,
tensor
,
vector
from
aesara.tensor.type
import
iscalar
,
ivector
,
scalar
,
tensor
,
vector
from
aesara.tensor.var
import
TensorConstant
,
TensorVariable
,
_tensor_py_operators
sparse_formats
=
[
"csc"
,
"csr"
]
sparse_formats
=
[
"csc"
,
"csr"
]
...
@@ -126,8 +127,7 @@ def _is_dense(x):
...
@@ -126,8 +127,7 @@ def _is_dense(x):
return
isinstance
(
x
,
np
.
ndarray
)
return
isinstance
(
x
,
np
.
ndarray
)
# Wrapper type
def
as_sparse_variable
(
x
,
name
=
None
,
ndim
=
None
,
**
kwargs
):
def
as_sparse_variable
(
x
,
name
=
None
):
"""
"""
Wrapper around SparseVariable constructor to construct
Wrapper around SparseVariable constructor to construct
a Variable with a sparse matrix with the same dtype and
a Variable with a sparse matrix with the same dtype and
...
@@ -250,7 +250,7 @@ def sp_zeros_like(x):
...
@@ -250,7 +250,7 @@ def sp_zeros_like(x):
)
)
class
_sparse_py_operators
:
class
_sparse_py_operators
(
_tensor_py_operators
)
:
T
=
property
(
T
=
property
(
lambda
self
:
transpose
(
self
),
doc
=
"Return aliased transpose of self (read-only)"
lambda
self
:
transpose
(
self
),
doc
=
"Return aliased transpose of self (read-only)"
)
)
...
@@ -361,8 +361,7 @@ class _sparse_py_operators:
...
@@ -361,8 +361,7 @@ class _sparse_py_operators:
return
ret
return
ret
class
SparseVariable
(
_sparse_py_operators
,
Variable
):
class
SparseVariable
(
_sparse_py_operators
,
TensorVariable
):
dtype
=
property
(
lambda
self
:
self
.
type
.
dtype
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
def
__str__
(
self
):
def
__str__
(
self
):
...
@@ -395,8 +394,7 @@ class SparseConstantSignature(tuple):
...
@@ -395,8 +394,7 @@ class SparseConstantSignature(tuple):
return
hash_from_sparse
(
d
)
return
hash_from_sparse
(
d
)
class
SparseConstant
(
Constant
,
_sparse_py_operators
):
class
SparseConstant
(
TensorConstant
,
_sparse_py_operators
):
dtype
=
property
(
lambda
self
:
self
.
type
.
dtype
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
def
signature
(
self
):
def
signature
(
self
):
...
@@ -448,7 +446,7 @@ csc_fmatrix = SparseType(format="csc", dtype="float32")
...
@@ -448,7 +446,7 @@ csc_fmatrix = SparseType(format="csc", dtype="float32")
csr_fmatrix
=
SparseType
(
format
=
"csr"
,
dtype
=
"float32"
)
csr_fmatrix
=
SparseType
(
format
=
"csr"
,
dtype
=
"float32"
)
bsr_fmatrix
=
SparseType
(
format
=
"bsr"
,
dtype
=
"float32"
)
bsr_fmatrix
=
SparseType
(
format
=
"bsr"
,
dtype
=
"float32"
)
all_dtypes
=
SparseType
.
dtype_set
all_dtypes
=
list
(
SparseType
.
dtype_specs_map
.
keys
())
complex_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
7
]
==
"complex"
]
complex_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
7
]
==
"complex"
]
float_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
5
]
==
"float"
]
float_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
5
]
==
"float"
]
int_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
3
]
==
"int"
]
int_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
3
]
==
"int"
]
...
@@ -926,6 +924,12 @@ class DenseFromSparse(Op):
...
@@ -926,6 +924,12 @@ class DenseFromSparse(Op):
def
__str__
(
self
):
def
__str__
(
self
):
return
f
"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}"
return
f
"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}"
def
__call__
(
self
,
x
):
if
not
isinstance
(
x
.
type
,
SparseType
):
return
x
return
super
()
.
__call__
(
x
)
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
as_sparse_variable
(
x
)
x
=
as_sparse_variable
(
x
)
return
Apply
(
return
Apply
(
...
@@ -1003,6 +1007,12 @@ class SparseFromDense(Op):
...
@@ -1003,6 +1007,12 @@ class SparseFromDense(Op):
def
__str__
(
self
):
def
__str__
(
self
):
return
f
"{self.__class__.__name__}{{{self.format}}}"
return
f
"{self.__class__.__name__}{{{self.format}}}"
def
__call__
(
self
,
x
):
if
isinstance
(
x
.
type
,
SparseType
):
return
x
return
super
()
.
__call__
(
x
)
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
at
.
as_tensor_variable
(
x
)
x
=
at
.
as_tensor_variable
(
x
)
if
x
.
ndim
>
2
:
if
x
.
ndim
>
2
:
...
...
aesara/sparse/opt.py
浏览文件 @
fdf2a23a
...
@@ -23,6 +23,7 @@ from aesara.tensor import blas
...
@@ -23,6 +23,7 @@ from aesara.tensor import blas
from
aesara.tensor.basic
import
as_tensor_variable
,
cast
,
patternbroadcast
from
aesara.tensor.basic
import
as_tensor_variable
,
cast
,
patternbroadcast
from
aesara.tensor.basic_opt
import
register_canonicalize
,
register_specialize
from
aesara.tensor.basic_opt
import
register_canonicalize
,
register_specialize
from
aesara.tensor.math
import
mul
,
neg
,
sub
from
aesara.tensor.math
import
mul
,
neg
,
sub
from
aesara.tensor.shape
import
shape
,
specify_shape
from
aesara.tensor.type
import
TensorType
,
tensor
from
aesara.tensor.type
import
TensorType
,
tensor
...
@@ -2070,8 +2071,19 @@ def local_sampling_dot_csr(fgraph, node):
...
@@ -2070,8 +2071,19 @@ def local_sampling_dot_csr(fgraph, node):
z_data
,
z_ind
,
z_ptr
=
sampling_dot_csr
(
z_data
,
z_ind
,
z_ptr
=
sampling_dot_csr
(
x
,
y
,
p_data
,
p_ind
,
p_ptr
,
p_shape
[
1
]
x
,
y
,
p_data
,
p_ind
,
p_ptr
,
p_shape
[
1
]
)
)
# This is a hack that works around some missing `Type`-related
return
[
sparse
.
CSR
(
z_data
,
z_ind
,
z_ptr
,
p_shape
)]
# static shape narrowing. More specifically,
# `TensorType.convert_variable` currently won't combine the static
# shape information from `old_out.type` and `new_out.type`, only
# the broadcast patterns, and, since `CSR.make_node` doesn't do
# that either, we use `specify_shape` to produce an output `Type`
# with the same level of static shape information as the original
# `old_out`.
old_out
=
node
.
outputs
[
0
]
new_out
=
specify_shape
(
sparse
.
CSR
(
z_data
,
z_ind
,
z_ptr
,
p_shape
),
shape
(
old_out
)
)
return
[
new_out
]
return
False
return
False
...
...
aesara/sparse/type.py
浏览文件 @
fdf2a23a
...
@@ -2,7 +2,8 @@ import numpy as np
...
@@ -2,7 +2,8 @@ import numpy as np
import
scipy.sparse
import
scipy.sparse
import
aesara
import
aesara
from
aesara.graph.type
import
HasDataType
,
Type
from
aesara.graph.type
import
HasDataType
from
aesara.tensor.type
import
TensorType
def
_is_sparse
(
x
):
def
_is_sparse
(
x
):
...
@@ -24,7 +25,7 @@ def _is_sparse(x):
...
@@ -24,7 +25,7 @@ def _is_sparse(x):
return
isinstance
(
x
,
scipy
.
sparse
.
spmatrix
)
return
isinstance
(
x
,
scipy
.
sparse
.
spmatrix
)
class
SparseType
(
Type
,
HasDataType
):
class
SparseType
(
T
ensorT
ype
,
HasDataType
):
"""
"""
Fundamental way to create a sparse node.
Fundamental way to create a sparse node.
...
@@ -52,19 +53,19 @@ class SparseType(Type, HasDataType):
...
@@ -52,19 +53,19 @@ class SparseType(Type, HasDataType):
"csc"
:
scipy
.
sparse
.
csc_matrix
,
"csc"
:
scipy
.
sparse
.
csc_matrix
,
"bsr"
:
scipy
.
sparse
.
bsr_matrix
,
"bsr"
:
scipy
.
sparse
.
bsr_matrix
,
}
}
dtype_s
et
=
{
dtype_s
pecs_map
=
{
"
int8"
,
"
float32"
:
(
float
,
"npy_float32"
,
"NPY_FLOAT32"
)
,
"
int16"
,
"
float64"
:
(
float
,
"npy_float64"
,
"NPY_FLOAT64"
)
,
"
int32"
,
"
uint8"
:
(
int
,
"npy_uint8"
,
"NPY_UINT8"
)
,
"int
64"
,
"int
8"
:
(
int
,
"npy_int8"
,
"NPY_INT8"
)
,
"
float32"
,
"
uint16"
:
(
int
,
"npy_uint16"
,
"NPY_UINT16"
)
,
"
uint8"
,
"
int16"
:
(
int
,
"npy_int16"
,
"NPY_INT16"
)
,
"uint
16"
,
"uint
32"
:
(
int
,
"npy_uint32"
,
"NPY_UINT32"
)
,
"
uint32"
,
"
int32"
:
(
int
,
"npy_int32"
,
"NPY_INT32"
)
,
"uint64"
,
"uint64"
:
(
int
,
"npy_uint64"
,
"NPY_UINT64"
)
,
"
float64"
,
"
int64"
:
(
int
,
"npy_int64"
,
"NPY_INT64"
)
,
"complex
64"
,
"complex
128"
:
(
complex
,
"aesara_complex128"
,
"NPY_COMPLEX128"
)
,
"complex
128"
,
"complex
64"
:
(
complex
,
"aesara_complex64"
,
"NPY_COMPLEX64"
)
,
}
}
ndim
=
2
ndim
=
2
...
@@ -72,28 +73,25 @@ class SparseType(Type, HasDataType):
...
@@ -72,28 +73,25 @@ class SparseType(Type, HasDataType):
variable_type
=
None
variable_type
=
None
Constant
=
None
Constant
=
None
def
__init__
(
self
,
format
,
dtype
,
shape
=
None
):
def
__init__
(
self
,
format
,
dtype
,
shape
=
None
,
broadcastable
=
None
,
name
=
None
):
dtype
=
str
(
dtype
)
if
dtype
in
self
.
dtype_set
:
self
.
dtype
=
dtype
else
:
raise
NotImplementedError
(
f
'unsupported dtype "{dtype}" not in list'
,
list
(
self
.
dtype_set
)
)
if
shape
is
None
:
if
shape
is
None
:
shape
=
(
None
,
None
)
shape
=
(
None
,
None
)
self
.
shape
=
shape
self
.
shape
=
shape
assert
isinstance
(
format
,
str
)
if
not
isinstance
(
format
,
str
):
raise
TypeError
(
"The sparse format parameter must be a string"
)
if
format
in
self
.
format_cls
:
if
format
in
self
.
format_cls
:
self
.
format
=
format
self
.
format
=
format
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
f
'unsupported format "{format}" not in list'
,
f
'unsupported format "{format}" not in list'
,
list
(
self
.
format_cls
.
keys
()),
)
)
if
broadcastable
is
None
:
broadcastable
=
[
False
,
False
]
super
()
.
__init__
(
dtype
,
shape
,
name
=
name
)
def
clone
(
self
,
format
=
None
,
dtype
=
None
,
shape
=
None
,
**
kwargs
):
def
clone
(
self
,
format
=
None
,
dtype
=
None
,
shape
=
None
,
**
kwargs
):
if
format
is
None
:
if
format
is
None
:
...
@@ -153,21 +151,11 @@ class SparseType(Type, HasDataType):
...
@@ -153,21 +151,11 @@ class SparseType(Type, HasDataType):
def
make_variable
(
self
,
name
=
None
):
def
make_variable
(
self
,
name
=
None
):
return
self
.
variable_type
(
self
,
name
=
name
)
return
self
.
variable_type
(
self
,
name
=
name
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
other
.
dtype
==
self
.
dtype
and
other
.
format
==
self
.
format
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
self
.
dtype
)
^
hash
(
self
.
format
)
return
super
()
.
__hash__
()
^
hash
(
self
.
format
)
def
__str__
(
self
):
return
f
"Sparse[{self.dtype}, {self.format}]"
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
"Sparse
[{self.dtype}, {self.format}]
"
return
f
"Sparse
({self.dtype}, {self.shape}, {self.format})
"
def
values_eq_approx
(
self
,
a
,
b
,
eps
=
1e-6
):
def
values_eq_approx
(
self
,
a
,
b
,
eps
=
1e-6
):
# WARNING: equality comparison of sparse matrices is not fast or easy
# WARNING: equality comparison of sparse matrices is not fast or easy
...
@@ -210,6 +198,31 @@ class SparseType(Type, HasDataType):
...
@@ -210,6 +198,31 @@ class SparseType(Type, HasDataType):
+
(
shape_info
[
2
]
+
shape_info
[
3
])
*
np
.
dtype
(
"int32"
)
.
itemsize
+
(
shape_info
[
2
]
+
shape_info
[
3
])
*
np
.
dtype
(
"int32"
)
.
itemsize
)
)
def
value_zeros
(
self
,
shape
):
matrix_constructor
=
self
.
format_cls
.
get
(
self
.
format
)
if
matrix_constructor
is
None
:
raise
ValueError
(
f
"Sparse matrix type {self.format} not found in SciPy"
)
return
matrix_constructor
(
shape
,
dtype
=
self
.
dtype
)
def
__eq__
(
self
,
other
):
res
=
super
()
.
__eq__
(
other
)
if
isinstance
(
res
,
bool
):
return
res
and
other
.
format
==
self
.
format
return
res
def
is_super
(
self
,
otype
):
if
not
super
()
.
is_super
(
otype
):
return
False
if
self
.
format
==
otype
.
format
:
return
True
return
False
# Register SparseType's C code for ViewOp.
# Register SparseType's C code for ViewOp.
aesara
.
compile
.
register_view_op_c_code
(
aesara
.
compile
.
register_view_op_c_code
(
...
...
aesara/tensor/basic.py
浏览文件 @
fdf2a23a
...
@@ -313,8 +313,13 @@ def get_scalar_constant_value(
...
@@ -313,8 +313,13 @@ def get_scalar_constant_value(
return
np
.
array
(
data
.
item
(),
dtype
=
v
.
dtype
)
return
np
.
array
(
data
.
item
(),
dtype
=
v
.
dtype
)
except
ValueError
:
except
ValueError
:
raise
NotScalarConstantError
()
raise
NotScalarConstantError
()
else
:
return
data
from
aesara.sparse.type
import
SparseType
if
isinstance
(
v
.
type
,
SparseType
):
raise
NotScalarConstantError
()
return
data
if
not
only_process_constants
and
getattr
(
v
,
"owner"
,
None
)
and
max_recur
>
0
:
if
not
only_process_constants
and
getattr
(
v
,
"owner"
,
None
)
and
max_recur
>
0
:
max_recur
-=
1
max_recur
-=
1
...
...
aesara/tensor/basic_opt.py
浏览文件 @
fdf2a23a
...
@@ -78,7 +78,12 @@ from aesara.tensor.math import eq
...
@@ -78,7 +78,12 @@ from aesara.tensor.math import eq
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
shape_padleft
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
shape_padleft
from
aesara.tensor.sort
import
TopKOp
from
aesara.tensor.sort
import
TopKOp
from
aesara.tensor.subtensor
import
Subtensor
,
get_idx_list
from
aesara.tensor.subtensor
import
Subtensor
,
get_idx_list
from
aesara.tensor.type
import
TensorType
,
discrete_dtypes
,
integer_dtypes
from
aesara.tensor.type
import
(
DenseTensorType
,
TensorType
,
discrete_dtypes
,
integer_dtypes
,
)
from
aesara.tensor.var
import
TensorConstant
from
aesara.tensor.var
import
TensorConstant
from
aesara.utils
import
NoDuplicateOptWarningFilter
from
aesara.utils
import
NoDuplicateOptWarningFilter
...
@@ -2954,7 +2959,8 @@ def constant_folding(fgraph, node):
...
@@ -2954,7 +2959,8 @@ def constant_folding(fgraph, node):
# TODO: `Type` itself should provide an interface for constructing
# TODO: `Type` itself should provide an interface for constructing
# instances appropriate for a given constant.
# instances appropriate for a given constant.
if
isinstance
(
output
.
type
,
TensorType
):
# TODO: Add handling for sparse types.
if
isinstance
(
output
.
type
,
DenseTensorType
):
output_type
=
TensorType
(
output_type
=
TensorType
(
output
.
type
.
dtype
,
output
.
type
.
dtype
,
tuple
(
s
==
1
for
s
in
data
.
shape
),
tuple
(
s
==
1
for
s
in
data
.
shape
),
...
...
aesara/tensor/blas.py
浏览文件 @
fdf2a23a
...
@@ -167,7 +167,12 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
...
@@ -167,7 +167,12 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
Dot
,
add
,
mul
,
neg
,
sub
from
aesara.tensor.math
import
Dot
,
add
,
mul
,
neg
,
sub
from
aesara.tensor.type
import
integer_dtypes
,
tensor
,
values_eq_approx_remove_inf_nan
from
aesara.tensor.type
import
(
DenseTensorType
,
integer_dtypes
,
tensor
,
values_eq_approx_remove_inf_nan
,
)
from
aesara.utils
import
memoize
from
aesara.utils
import
memoize
...
@@ -264,7 +269,13 @@ class Gemv(Op):
...
@@ -264,7 +269,13 @@ class Gemv(Op):
raise
TypeError
(
"gemv requires vector for x"
,
x
.
type
)
raise
TypeError
(
"gemv requires vector for x"
,
x
.
type
)
if
y
.
ndim
!=
1
:
if
y
.
ndim
!=
1
:
raise
TypeError
(
"gemv requires vector for y"
,
y
.
type
)
raise
TypeError
(
"gemv requires vector for y"
,
y
.
type
)
return
Apply
(
self
,
[
y
,
alpha
,
A
,
x
,
beta
],
[
y
.
type
()])
inputs
=
[
y
,
alpha
,
A
,
x
,
beta
]
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
inputs
):
raise
NotImplementedError
(
"Only dense tensor types are supported"
)
return
Apply
(
self
,
inputs
,
[
y
.
type
()])
def
perform
(
self
,
node
,
inputs
,
out_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
out_storage
,
params
=
None
):
y
,
alpha
,
A
,
x
,
beta
=
inputs
y
,
alpha
,
A
,
x
,
beta
=
inputs
...
@@ -361,7 +372,12 @@ class Ger(Op):
...
@@ -361,7 +372,12 @@ class Ger(Op):
if
x
.
dtype
not
in
(
"float32"
,
"float64"
,
"complex64"
,
"complex128"
):
if
x
.
dtype
not
in
(
"float32"
,
"float64"
,
"complex64"
,
"complex128"
):
raise
TypeError
(
"only float and complex types supported"
,
x
.
dtype
)
raise
TypeError
(
"only float and complex types supported"
,
x
.
dtype
)
return
Apply
(
self
,
[
A
,
alpha
,
x
,
y
],
[
A
.
type
()])
inputs
=
[
A
,
alpha
,
x
,
y
]
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
inputs
):
raise
NotImplementedError
(
"Only dense tensor types are supported"
)
return
Apply
(
self
,
inputs
,
[
A
.
type
()])
def
perform
(
self
,
node
,
inp
,
out
,
params
=
None
):
def
perform
(
self
,
node
,
inp
,
out
,
params
=
None
):
cA
,
calpha
,
cx
,
cy
=
inp
cA
,
calpha
,
cx
,
cy
=
inp
...
@@ -899,6 +915,10 @@ class Gemm(GemmRelated):
...
@@ -899,6 +915,10 @@ class Gemm(GemmRelated):
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
inputs
=
list
(
map
(
at
.
as_tensor_variable
,
inputs
))
inputs
=
list
(
map
(
at
.
as_tensor_variable
,
inputs
))
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
inputs
):
raise
NotImplementedError
(
"Only dense tensor types are supported"
)
if
len
(
inputs
)
!=
5
:
if
len
(
inputs
)
!=
5
:
raise
TypeError
(
raise
TypeError
(
f
"Wrong number of inputs for {self} (expected 5, got {len(inputs)})"
f
"Wrong number of inputs for {self} (expected 5, got {len(inputs)})"
...
@@ -1580,6 +1600,10 @@ class Dot22(GemmRelated):
...
@@ -1580,6 +1600,10 @@ class Dot22(GemmRelated):
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
=
at
.
as_tensor_variable
(
x
)
x
=
at
.
as_tensor_variable
(
x
)
y
=
at
.
as_tensor_variable
(
y
)
y
=
at
.
as_tensor_variable
(
y
)
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
(
x
,
y
)):
raise
NotImplementedError
(
"Only dense tensor types are supported"
)
dtypes
=
(
"float16"
,
"float32"
,
"float64"
,
"complex64"
,
"complex128"
)
dtypes
=
(
"float16"
,
"float32"
,
"float64"
,
"complex64"
,
"complex128"
)
if
x
.
type
.
ndim
!=
2
or
x
.
type
.
dtype
not
in
dtypes
:
if
x
.
type
.
ndim
!=
2
or
x
.
type
.
dtype
not
in
dtypes
:
raise
TypeError
(
x
)
raise
TypeError
(
x
)
...
@@ -1665,6 +1689,9 @@ def local_dot_to_dot22(fgraph, node):
...
@@ -1665,6 +1689,9 @@ def local_dot_to_dot22(fgraph, node):
if
not
isinstance
(
node
.
op
,
Dot
):
if
not
isinstance
(
node
.
op
,
Dot
):
return
return
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
node
.
inputs
):
return
False
x
,
y
=
node
.
inputs
x
,
y
=
node
.
inputs
if
y
.
type
.
dtype
!=
x
.
type
.
dtype
:
if
y
.
type
.
dtype
!=
x
.
type
.
dtype
:
# TODO: upcast one so the types match
# TODO: upcast one so the types match
...
@@ -1869,6 +1896,10 @@ class Dot22Scalar(GemmRelated):
...
@@ -1869,6 +1896,10 @@ class Dot22Scalar(GemmRelated):
check_input
=
False
check_input
=
False
def
make_node
(
self
,
x
,
y
,
a
):
def
make_node
(
self
,
x
,
y
,
a
):
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
(
x
,
y
,
a
)):
raise
NotImplementedError
(
"Only dense tensor types are supported"
)
if
a
.
ndim
!=
0
:
if
a
.
ndim
!=
0
:
raise
TypeError
(
Gemm
.
E_scalar
,
a
)
raise
TypeError
(
Gemm
.
E_scalar
,
a
)
if
x
.
ndim
!=
2
:
if
x
.
ndim
!=
2
:
...
@@ -2089,6 +2120,9 @@ class BatchedDot(COp):
...
@@ -2089,6 +2120,9 @@ class BatchedDot(COp):
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
inputs
=
list
(
map
(
at
.
as_tensor_variable
,
inputs
))
inputs
=
list
(
map
(
at
.
as_tensor_variable
,
inputs
))
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
inputs
):
raise
NotImplementedError
(
"Only dense tensor types are supported"
)
if
len
(
inputs
)
!=
2
:
if
len
(
inputs
)
!=
2
:
raise
TypeError
(
f
"Two arguments required, but {len(inputs)} given."
)
raise
TypeError
(
f
"Two arguments required, but {len(inputs)} given."
)
if
inputs
[
0
]
.
ndim
not
in
(
2
,
3
):
if
inputs
[
0
]
.
ndim
not
in
(
2
,
3
):
...
...
aesara/tensor/math.py
浏览文件 @
fdf2a23a
...
@@ -34,6 +34,7 @@ from aesara.tensor.elemwise import (
...
@@ -34,6 +34,7 @@ from aesara.tensor.elemwise import (
)
)
from
aesara.tensor.shape
import
shape
from
aesara.tensor.shape
import
shape
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
DenseTensorType
,
complex_dtypes
,
complex_dtypes
,
continuous_dtypes
,
continuous_dtypes
,
discrete_dtypes
,
discrete_dtypes
,
...
@@ -2076,6 +2077,11 @@ def dense_dot(a, b):
...
@@ -2076,6 +2077,11 @@ def dense_dot(a, b):
"""
"""
a
,
b
=
as_tensor_variable
(
a
),
as_tensor_variable
(
b
)
a
,
b
=
as_tensor_variable
(
a
),
as_tensor_variable
(
b
)
if
not
isinstance
(
a
.
type
,
DenseTensorType
)
or
not
isinstance
(
b
.
type
,
DenseTensorType
):
raise
TypeError
(
"The dense dot product is only supported for dense types"
)
if
a
.
ndim
==
0
or
b
.
ndim
==
0
:
if
a
.
ndim
==
0
or
b
.
ndim
==
0
:
return
a
*
b
return
a
*
b
elif
a
.
ndim
>
2
or
b
.
ndim
>
2
:
elif
a
.
ndim
>
2
or
b
.
ndim
>
2
:
...
...
aesara/tensor/shape.py
浏览文件 @
fdf2a23a
...
@@ -431,7 +431,7 @@ class SpecifyShape(COp):
...
@@ -431,7 +431,7 @@ class SpecifyShape(COp):
)
)
if
isinstance
(
x
.
type
,
TensorType
)
and
all
(
isinstance
(
s
,
Number
)
for
s
in
shape
):
if
isinstance
(
x
.
type
,
TensorType
)
and
all
(
isinstance
(
s
,
Number
)
for
s
in
shape
):
out_var
=
TensorType
(
x
.
type
.
dtype
,
shape
)()
out_var
=
x
.
type
.
clone
(
shape
=
shape
)()
else
:
else
:
out_var
=
x
.
type
()
out_var
=
x
.
type
()
...
...
aesara/tensor/type.py
浏览文件 @
fdf2a23a
import
logging
import
logging
import
warnings
import
warnings
from
typing
import
Iterable
,
Optional
,
Union
from
typing
import
Iterable
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -9,6 +9,7 @@ from aesara import scalar as aes
...
@@ -9,6 +9,7 @@ from aesara import scalar as aes
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Variable
from
aesara.graph.basic
import
Variable
from
aesara.graph.type
import
HasDataType
from
aesara.graph.type
import
HasDataType
from
aesara.graph.utils
import
MetaType
from
aesara.link.c.type
import
CType
from
aesara.link.c.type
import
CType
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.utils
import
apply_across_args
from
aesara.utils
import
apply_across_args
...
@@ -50,8 +51,9 @@ dtype_specs_map = {
...
@@ -50,8 +51,9 @@ dtype_specs_map = {
class
TensorType
(
CType
,
HasDataType
):
class
TensorType
(
CType
,
HasDataType
):
r"""Symbolic `Type` representing `numpy.ndarray`\s."""
r"""Symbolic `Type` representing `numpy.ndarray`\s."""
__props__
=
(
"dtype"
,
"shape"
)
__props__
:
Tuple
[
str
,
...
]
=
(
"dtype"
,
"shape"
)
dtype_specs_map
=
dtype_specs_map
context_name
=
"cpu"
context_name
=
"cpu"
filter_checks_isfinite
=
False
filter_checks_isfinite
=
False
"""
"""
...
@@ -271,7 +273,7 @@ class TensorType(CType, HasDataType):
...
@@ -271,7 +273,7 @@ class TensorType(CType, HasDataType):
"""
"""
try
:
try
:
return
dtype_specs_map
[
self
.
dtype
]
return
self
.
dtype_specs_map
[
self
.
dtype
]
except
KeyError
:
except
KeyError
:
raise
TypeError
(
raise
TypeError
(
f
"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
f
"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
...
@@ -613,6 +615,20 @@ class TensorType(CType, HasDataType):
...
@@ -613,6 +615,20 @@ class TensorType(CType, HasDataType):
return
()
return
()
class
DenseTypeMeta
(
MetaType
):
def
__instancecheck__
(
self
,
o
):
if
type
(
o
)
==
TensorType
or
isinstance
(
o
,
DenseTypeMeta
):
return
True
return
False
class
DenseTensorType
(
TensorType
,
metaclass
=
DenseTypeMeta
):
r"""A `Type` for dense tensors.
Instances of this class and `TensorType`\s are considered dense `Type`\s.
"""
def
values_eq_approx
(
def
values_eq_approx
(
a
,
b
,
allow_remove_inf
=
False
,
allow_remove_nan
=
False
,
rtol
=
None
,
atol
=
None
a
,
b
,
allow_remove_inf
=
False
,
allow_remove_nan
=
False
,
rtol
=
None
,
atol
=
None
):
):
...
...
aesara/tensor/var.py
浏览文件 @
fdf2a23a
...
@@ -10,6 +10,7 @@ import numpy as np
...
@@ -10,6 +10,7 @@ import numpy as np
from
aesara
import
tensor
as
at
from
aesara
import
tensor
as
at
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.utils
import
MetaType
from
aesara.scalar
import
ComplexError
,
IntegerDivisionError
from
aesara.scalar
import
ComplexError
,
IntegerDivisionError
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
from
aesara.tensor.exceptions
import
AdvancedIndexingError
from
aesara.tensor.exceptions
import
AdvancedIndexingError
...
@@ -1040,3 +1041,33 @@ class TensorConstant(TensorVariable, Constant):
...
@@ -1040,3 +1041,33 @@ class TensorConstant(TensorVariable, Constant):
TensorType
.
constant_type
=
TensorConstant
TensorType
.
constant_type
=
TensorConstant
class
DenseVariableMeta
(
MetaType
):
def
__instancecheck__
(
self
,
o
):
if
type
(
o
)
==
TensorVariable
or
isinstance
(
o
,
DenseVariableMeta
):
return
True
return
False
class
DenseTensorVariable
(
TensorType
,
metaclass
=
DenseVariableMeta
):
r"""A `Variable` for dense tensors.
Instances of this class and `TensorVariable`\s are considered dense
`Variable`\s.
"""
class
DenseConstantMeta
(
MetaType
):
def
__instancecheck__
(
self
,
o
):
if
type
(
o
)
==
TensorConstant
or
isinstance
(
o
,
DenseConstantMeta
):
return
True
return
False
class
DenseTensorConstant
(
TensorType
,
metaclass
=
DenseConstantMeta
):
r"""A `Constant` for dense tensors.
Instances of this class and `TensorConstant`\s are considered dense
`Constant`\s.
"""
tests/sparse/test_basic.py
浏览文件 @
fdf2a23a
...
@@ -12,7 +12,7 @@ from aesara.compile.function import function
...
@@ -12,7 +12,7 @@ from aesara.compile.function import function
from
aesara.compile.io
import
In
,
Out
from
aesara.compile.io
import
In
,
Out
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.gradient
import
GradientError
from
aesara.gradient
import
GradientError
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.basic
import
Apply
,
Constant
,
applys_between
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.sparse
import
(
from
aesara.sparse
import
(
...
@@ -78,6 +78,7 @@ from aesara.sparse import (
...
@@ -78,6 +78,7 @@ from aesara.sparse import (
true_dot
,
true_dot
,
)
)
from
aesara.sparse.basic
import
(
from
aesara.sparse.basic
import
(
SparseConstant
,
_is_dense_variable
,
_is_dense_variable
,
_is_sparse
,
_is_sparse
,
_is_sparse_variable
,
_is_sparse_variable
,
...
@@ -1017,22 +1018,45 @@ class TestComparison:
...
@@ -1017,22 +1018,45 @@ class TestComparison:
class
TestConversion
:
class
TestConversion
:
@pytest.mark.skip
def
test_basic
(
self
):
def
test_basic
(
self
):
a
=
at
.
as_tensor_variable
(
np
.
random
.
random
((
5
)))
test_val
=
np
.
random
.
rand
(
5
)
.
astype
(
config
.
floatX
)
a
=
at
.
as_tensor_variable
(
test_val
)
s
=
csc_from_dense
(
a
)
s
=
csc_from_dense
(
a
)
val
=
eval_outputs
([
s
])
val
=
eval_outputs
([
s
])
assert
str
(
val
.
dtype
)
==
"float64"
assert
str
(
val
.
dtype
)
==
config
.
floatX
assert
val
.
format
==
"csc"
assert
val
.
format
==
"csc"
@pytest.mark.skip
a
=
at
.
as_tensor_variable
(
test_val
)
def
test_basic_1
(
self
):
a
=
at
.
as_tensor_variable
(
np
.
random
.
random
((
5
)))
s
=
csr_from_dense
(
a
)
s
=
csr_from_dense
(
a
)
val
=
eval_outputs
([
s
])
val
=
eval_outputs
([
s
])
assert
str
(
val
.
dtype
)
==
"float64"
assert
str
(
val
.
dtype
)
==
config
.
floatX
assert
val
.
format
==
"csr"
assert
val
.
format
==
"csr"
test_val
=
np
.
eye
(
3
)
.
astype
(
config
.
floatX
)
a
=
sp
.
sparse
.
csr_matrix
(
test_val
)
s
=
as_sparse_or_tensor_variable
(
a
)
res
=
at
.
as_tensor_variable
(
s
)
assert
isinstance
(
res
,
SparseConstant
)
a
=
sp
.
sparse
.
csr_matrix
(
test_val
)
s
=
as_sparse_or_tensor_variable
(
a
)
from
aesara.tensor.exceptions
import
NotScalarConstantError
with
pytest
.
raises
(
NotScalarConstantError
):
at
.
get_scalar_constant_value
(
s
,
only_process_constants
=
True
)
# TODO:
# def test_sparse_as_tensor_variable(self):
# csr = sp.sparse.csr_matrix(np.eye(3))
# val = aet.as_tensor_variable(csr)
# assert str(val.dtype) == config.floatX
# assert val.format == "csr"
#
# csr = sp.sparse.csc_matrix(np.eye(3))
# val = aet.as_tensor_variable(csr)
# assert str(val.dtype) == config.floatX
# assert val.format == "csc"
def
test_dense_from_sparse
(
self
):
def
test_dense_from_sparse
(
self
):
# call dense_from_sparse
# call dense_from_sparse
for
t
in
_mtypes
:
for
t
in
_mtypes
:
...
@@ -1591,6 +1615,32 @@ class TestDots(utt.InferShapeTester):
...
@@ -1591,6 +1615,32 @@ class TestDots(utt.InferShapeTester):
)
)
f
(
i
,
a
)
f
(
i
,
a
)
def
test_tensor_dot_types
(
self
):
x
=
sparse
.
csc_matrix
(
"x"
)
x_d
=
at
.
matrix
(
"x_d"
)
y
=
sparse
.
csc_matrix
(
"y"
)
res
=
at
.
dot
(
x
,
y
)
op_types
=
set
(
type
(
n
.
op
)
for
n
in
applys_between
([
x
,
y
],
[
res
]))
assert
sparse
.
basic
.
StructuredDot
in
op_types
assert
at
.
math
.
Dot
not
in
op_types
res
=
at
.
dot
(
x_d
,
y
)
op_types
=
set
(
type
(
n
.
op
)
for
n
in
applys_between
([
x
,
y
],
[
res
]))
assert
sparse
.
basic
.
StructuredDot
in
op_types
assert
at
.
math
.
Dot
not
in
op_types
res
=
at
.
dot
(
x
,
x_d
)
op_types
=
set
(
type
(
n
.
op
)
for
n
in
applys_between
([
x
,
y
],
[
res
]))
assert
sparse
.
basic
.
StructuredDot
in
op_types
assert
at
.
math
.
Dot
not
in
op_types
res
=
at
.
dot
(
at
.
second
(
1
,
x
),
y
)
op_types
=
set
(
type
(
n
.
op
)
for
n
in
applys_between
([
x
,
y
],
[
res
]))
assert
sparse
.
basic
.
StructuredDot
in
op_types
assert
at
.
math
.
Dot
not
in
op_types
def
test_csr_dense_grad
(
self
):
def
test_csr_dense_grad
(
self
):
# shortcut: testing csc in float32, testing csr in float64
# shortcut: testing csc in float32, testing csr in float64
...
...
tests/sparse/test_type.py
浏览文件 @
fdf2a23a
import
pytest
from
aesara.sparse
import
matrix
as
sp_matrix
from
aesara.sparse.type
import
SparseType
from
aesara.sparse.type
import
SparseType
from
aesara.tensor
import
dmatrix
def
test_clone
():
def
test_clone
():
st
=
SparseType
(
"csr"
,
"float64"
)
st
=
SparseType
(
"csr"
,
"float64"
)
assert
st
==
st
.
clone
()
assert
st
==
st
.
clone
()
def
test_Sparse_convert_variable
():
x
=
dmatrix
(
name
=
"x"
)
y
=
sp_matrix
(
"csc"
,
dtype
=
"float64"
,
name
=
"y"
)
z
=
sp_matrix
(
"csr"
,
dtype
=
"float64"
,
name
=
"z"
)
assert
y
.
type
.
convert_variable
(
z
)
is
None
# TODO FIXME: This is a questionable result, because `x.type` is associated
# with a dense `Type`, but, since `TensorType` is a base class of `Sparse`,
# we would need to added sparse/dense logic to `TensorType`, and we don't
# want to do that.
assert
x
.
type
.
convert_variable
(
y
)
is
y
# TODO FIXME: We should be able to do this.
with
pytest
.
raises
(
NotImplementedError
):
y
.
type
.
convert_variable
(
x
)
tests/tensor/test_var.py
浏览文件 @
fdf2a23a
...
@@ -6,6 +6,7 @@ import aesara
...
@@ -6,6 +6,7 @@ import aesara
import
tests.unittest_tools
as
utt
import
tests.unittest_tools
as
utt
from
aesara.graph.basic
import
Constant
,
equal_computations
from
aesara.graph.basic
import
Constant
,
equal_computations
from
aesara.tensor
import
get_vector_length
from
aesara.tensor
import
get_vector_length
from
aesara.tensor.basic
import
constant
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.math
import
dot
from
aesara.tensor.math
import
dot
from
aesara.tensor.subtensor
import
AdvancedSubtensor
,
Subtensor
from
aesara.tensor.subtensor
import
AdvancedSubtensor
,
Subtensor
...
@@ -21,7 +22,12 @@ from aesara.tensor.type import (
...
@@ -21,7 +22,12 @@ from aesara.tensor.type import (
tensor3
,
tensor3
,
)
)
from
aesara.tensor.type_other
import
MakeSlice
from
aesara.tensor.type_other
import
MakeSlice
from
aesara.tensor.var
import
TensorConstant
,
TensorVariable
from
aesara.tensor.var
import
(
DenseTensorConstant
,
DenseTensorVariable
,
TensorConstant
,
TensorVariable
,
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -247,3 +253,14 @@ def test_get_vector_length():
...
@@ -247,3 +253,14 @@ def test_get_vector_length():
x
=
TensorVariable
(
TensorType
(
"int64"
,
(
None
,)))
x
=
TensorVariable
(
TensorType
(
"int64"
,
(
None
,)))
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
get_vector_length
(
x
)
get_vector_length
(
x
)
def
test_dense_types
():
x
=
matrix
()
assert
isinstance
(
x
,
DenseTensorVariable
)
assert
not
isinstance
(
x
,
DenseTensorConstant
)
x
=
constant
(
1
)
assert
not
isinstance
(
x
,
DenseTensorVariable
)
assert
isinstance
(
x
,
DenseTensorConstant
)
tests/test_ifelse.py
浏览文件 @
fdf2a23a
...
@@ -335,13 +335,13 @@ class TestIfelse(utt.OptimizationTestMixin):
...
@@ -335,13 +335,13 @@ class TestIfelse(utt.OptimizationTestMixin):
z
=
aesara
.
sparse
.
matrix
(
"csr"
,
dtype
=
self
.
dtype
,
name
=
"z"
)
z
=
aesara
.
sparse
.
matrix
(
"csr"
,
dtype
=
self
.
dtype
,
name
=
"z"
)
cond
=
iscalar
(
"cond"
)
cond
=
iscalar
(
"cond"
)
with
pytest
.
raises
(
Type
Error
):
with
pytest
.
raises
(
NotImplemented
Error
):
ifelse
(
cond
,
x
,
y
)
ifelse
(
cond
,
x
,
y
)
with
pytest
.
raises
(
Type
Error
):
with
pytest
.
raises
(
NotImplemented
Error
):
ifelse
(
cond
,
y
,
x
)
ifelse
(
cond
,
y
,
x
)
with
pytest
.
raises
(
Type
Error
):
with
pytest
.
raises
(
NotImplemented
Error
):
ifelse
(
cond
,
x
,
z
)
ifelse
(
cond
,
x
,
z
)
with
pytest
.
raises
(
Type
Error
):
with
pytest
.
raises
(
NotImplemented
Error
):
ifelse
(
cond
,
z
,
x
)
ifelse
(
cond
,
z
,
x
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
ifelse
(
cond
,
y
,
z
)
ifelse
(
cond
,
y
,
z
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论