Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d251eb7e
提交
d251eb7e
authored
2月 06, 2012
作者:
David Warde-Farley
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PEP8: fix all instances of E302 (2 blank lines)
上级
50387206
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
50 行增加
和
5 行删除
+50
-5
basic.py
theano/sparse/basic.py
+50
-5
没有找到文件。
theano/sparse/basic.py
浏览文件 @
d251eb7e
...
@@ -21,6 +21,7 @@ from theano.tensor import blas
...
@@ -21,6 +21,7 @@ from theano.tensor import blas
sparse_formats
=
[
'csc'
,
'csr'
]
sparse_formats
=
[
'csc'
,
'csr'
]
#TODO: move this decorator to the compile submodule
#TODO: move this decorator to the compile submodule
def
register_specialize
(
lopt
,
*
tags
,
**
kwargs
):
def
register_specialize
(
lopt
,
*
tags
,
**
kwargs
):
compile
.
optdb
[
'specialize'
]
.
register
((
kwargs
and
kwargs
.
pop
(
'name'
))
or
lopt
.
__name__
,
lopt
,
'fast_run'
,
*
tags
)
compile
.
optdb
[
'specialize'
]
.
register
((
kwargs
and
kwargs
.
pop
(
'name'
))
or
lopt
.
__name__
,
lopt
,
'fast_run'
,
*
tags
)
...
@@ -33,6 +34,7 @@ _mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
...
@@ -33,6 +34,7 @@ _mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
#* new class ``bsr_matrix`` : the Block CSR format
#* new class ``bsr_matrix`` : the Block CSR format
_mtype_to_str
=
{
scipy
.
sparse
.
csc_matrix
:
"csc"
,
scipy
.
sparse
.
csr_matrix
:
"csr"
}
_mtype_to_str
=
{
scipy
.
sparse
.
csc_matrix
:
"csc"
,
scipy
.
sparse
.
csr_matrix
:
"csr"
}
def
_is_sparse_variable
(
x
):
def
_is_sparse_variable
(
x
):
"""
"""
@rtype: boolean
@rtype: boolean
...
@@ -41,6 +43,8 @@ def _is_sparse_variable(x):
...
@@ -41,6 +43,8 @@ def _is_sparse_variable(x):
if
not
isinstance
(
x
.
type
,
(
SparseType
,
tensor
.
TensorType
)):
if
not
isinstance
(
x
.
type
,
(
SparseType
,
tensor
.
TensorType
)):
raise
NotImplementedError
(
"this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,"
,
x
)
raise
NotImplementedError
(
"this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,"
,
x
)
return
isinstance
(
x
.
type
,
SparseType
)
return
isinstance
(
x
.
type
,
SparseType
)
def
_is_dense_variable
(
x
):
def
_is_dense_variable
(
x
):
"""
"""
@rtype: boolean
@rtype: boolean
...
@@ -50,6 +54,7 @@ def _is_dense_variable(x):
...
@@ -50,6 +54,7 @@ def _is_dense_variable(x):
raise
NotImplementedError
(
"this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,"
,
x
)
raise
NotImplementedError
(
"this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,"
,
x
)
return
isinstance
(
x
.
type
,
tensor
.
TensorType
)
return
isinstance
(
x
.
type
,
tensor
.
TensorType
)
def
_is_sparse
(
x
):
def
_is_sparse
(
x
):
"""
"""
@rtype: boolean
@rtype: boolean
...
@@ -58,6 +63,8 @@ def _is_sparse(x):
...
@@ -58,6 +63,8 @@ def _is_sparse(x):
if
not
isinstance
(
x
,
(
scipy
.
sparse
.
spmatrix
,
numpy
.
ndarray
)):
if
not
isinstance
(
x
,
(
scipy
.
sparse
.
spmatrix
,
numpy
.
ndarray
)):
raise
NotImplementedError
(
"this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,"
,
x
)
raise
NotImplementedError
(
"this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,"
,
x
)
return
isinstance
(
x
,
scipy
.
sparse
.
spmatrix
)
return
isinstance
(
x
,
scipy
.
sparse
.
spmatrix
)
def
_is_dense
(
x
):
def
_is_dense
(
x
):
"""
"""
@rtype: boolean
@rtype: boolean
...
@@ -67,18 +74,19 @@ def _is_dense(x):
...
@@ -67,18 +74,19 @@ def _is_dense(x):
raise
NotImplementedError
(
"this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,"
,
x
)
raise
NotImplementedError
(
"this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,"
,
x
)
return
isinstance
(
x
,
numpy
.
ndarray
)
return
isinstance
(
x
,
numpy
.
ndarray
)
def
_kmap_eq
(
a
,
b
):
def
_kmap_eq
(
a
,
b
):
if
a
is
None
and
b
is
None
:
if
a
is
None
and
b
is
None
:
return
True
return
True
return
numpy
.
all
(
a
==
b
)
return
numpy
.
all
(
a
==
b
)
def
_kmap_hash
(
a
):
def
_kmap_hash
(
a
):
if
a
is
None
:
return
12345
if
a
is
None
:
return
12345
return
hash
(
numpy
.
str
(
a
))
return
hash
(
numpy
.
str
(
a
))
# Wrapper type
# Wrapper type
def
as_sparse_variable
(
x
,
name
=
None
):
def
as_sparse_variable
(
x
,
name
=
None
):
"""
"""
Wrapper around SparseVariable constructor.
Wrapper around SparseVariable constructor.
...
@@ -101,9 +109,9 @@ def as_sparse_variable(x, name=None):
...
@@ -101,9 +109,9 @@ def as_sparse_variable(x, name=None):
return
constant
(
x
,
name
=
name
)
return
constant
(
x
,
name
=
name
)
except
TypeError
:
except
TypeError
:
raise
TypeError
(
"Cannot convert
%
s to SparseType"
%
x
,
type
(
x
))
raise
TypeError
(
"Cannot convert
%
s to SparseType"
%
x
,
type
(
x
))
as_sparse
=
as_sparse_variable
as_sparse
=
as_sparse_variable
def
as_sparse_or_tensor_variable
(
x
,
name
=
None
):
def
as_sparse_or_tensor_variable
(
x
,
name
=
None
):
"""
"""
If we can't make a sparse variable, we try to make a tensor variable.
If we can't make a sparse variable, we try to make a tensor variable.
...
@@ -133,6 +141,7 @@ if 0:
...
@@ -133,6 +141,7 @@ if 0:
except
TypeError
:
except
TypeError
:
raise
TypeError
(
"Could not convert
%
s to SparseType"
%
x
,
type
(
x
))
raise
TypeError
(
"Could not convert
%
s to SparseType"
%
x
,
type
(
x
))
def
sp_ones_like
(
x
):
def
sp_ones_like
(
x
):
data
,
indices
,
indptr
,
shape
=
csm_properties
(
x
)
#TODO: don't restrict to CSM formats
data
,
indices
,
indptr
,
shape
=
csm_properties
(
x
)
#TODO: don't restrict to CSM formats
return
CSM
(
format
=
x
.
format
)(
tensor
.
ones_like
(
data
),
indices
,
indptr
,
shape
)
return
CSM
(
format
=
x
.
format
)(
tensor
.
ones_like
(
data
),
indices
,
indptr
,
shape
)
...
@@ -213,6 +222,7 @@ class SparseVariable(gof.Variable, _sparse_py_operators):
...
@@ -213,6 +222,7 @@ class SparseVariable(gof.Variable, _sparse_py_operators):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
str
(
self
)
return
str
(
self
)
class
SparseConstantSignature
(
tuple
):
class
SparseConstantSignature
(
tuple
):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
(
a
,
b
),
(
x
,
y
)
=
self
,
other
(
a
,
b
),
(
x
,
y
)
=
self
,
other
...
@@ -225,6 +235,7 @@ class SparseConstantSignature(tuple):
...
@@ -225,6 +235,7 @@ class SparseConstantSignature(tuple):
(
a
,
b
)
=
self
(
a
,
b
)
=
self
return
hash
(
type
(
self
))
^
hash
(
a
)
^
hash
(
type
(
b
))
return
hash
(
type
(
self
))
^
hash
(
a
)
^
hash
(
type
(
b
))
class
SparseConstant
(
gof
.
Constant
,
_sparse_py_operators
):
class
SparseConstant
(
gof
.
Constant
,
_sparse_py_operators
):
dtype
=
property
(
lambda
self
:
self
.
type
.
dtype
)
dtype
=
property
(
lambda
self
:
self
.
type
.
dtype
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
...
@@ -242,10 +253,12 @@ class SparseConstant(gof.Constant, _sparse_py_operators):
...
@@ -242,10 +253,12 @@ class SparseConstant(gof.Constant, _sparse_py_operators):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
str
(
self
)
return
str
(
self
)
class
SparseValue
(
gof
.
Value
,
_sparse_py_operators
):
class
SparseValue
(
gof
.
Value
,
_sparse_py_operators
):
dtype
=
property
(
lambda
self
:
self
.
type
.
dtype
)
dtype
=
property
(
lambda
self
:
self
.
type
.
dtype
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
class
SparseType
(
gof
.
Type
):
class
SparseType
(
gof
.
Type
):
"""
"""
@type dtype: numpy dtype string such as 'int64' or 'float64' (among others)
@type dtype: numpy dtype string such as 'int64' or 'float64' (among others)
...
@@ -366,8 +379,12 @@ def matrix(format, name=None, dtype=None):
...
@@ -366,8 +379,12 @@ def matrix(format, name=None, dtype=None):
dtype
=
config
.
floatX
dtype
=
config
.
floatX
type
=
SparseType
(
format
=
format
,
dtype
=
dtype
)
type
=
SparseType
(
format
=
format
,
dtype
=
dtype
)
return
type
(
name
)
return
type
(
name
)
def
csc_matrix
(
name
=
None
,
dtype
=
None
):
def
csc_matrix
(
name
=
None
,
dtype
=
None
):
return
matrix
(
'csc'
,
name
,
dtype
)
return
matrix
(
'csc'
,
name
,
dtype
)
def
csr_matrix
(
name
=
None
,
dtype
=
None
):
def
csr_matrix
(
name
=
None
,
dtype
=
None
):
return
matrix
(
'csr'
,
name
,
dtype
)
return
matrix
(
'csr'
,
name
,
dtype
)
# for more dtypes, call SparseType(format, dtype)
# for more dtypes, call SparseType(format, dtype)
...
@@ -378,6 +395,7 @@ csr_dmatrix = SparseType(format='csr', dtype='float64')
...
@@ -378,6 +395,7 @@ csr_dmatrix = SparseType(format='csr', dtype='float64')
csc_fmatrix
=
SparseType
(
format
=
'csc'
,
dtype
=
'float32'
)
csc_fmatrix
=
SparseType
(
format
=
'csc'
,
dtype
=
'float32'
)
csr_fmatrix
=
SparseType
(
format
=
'csr'
,
dtype
=
'float32'
)
csr_fmatrix
=
SparseType
(
format
=
'csr'
,
dtype
=
'float32'
)
# CONSTRUCTION
# CONSTRUCTION
class
CSMProperties
(
gof
.
Op
):
class
CSMProperties
(
gof
.
Op
):
"""Extract all of .data .indices and .indptr"""
"""Extract all of .data .indices and .indptr"""
...
@@ -427,11 +445,20 @@ class CSMProperties(gof.Op):
...
@@ -427,11 +445,20 @@ class CSMProperties(gof.Op):
else
:
else
:
return
[
CSR
(
'csm'
)(
g_data
,
indices
,
indptr
,
shape
)]
return
[
CSR
(
'csm'
)(
g_data
,
indices
,
indptr
,
shape
)]
csm_properties
=
CSMProperties
()
#don't make this a function or it breaks some optimizations below
csm_properties
=
CSMProperties
()
#don't make this a function or it breaks some optimizations below
def
csm_data
(
csm
):
return
csm_properties
(
csm
)[
0
]
def
csm_data
(
csm
):
return
csm_properties
(
csm
)[
0
]
def
csm_indices
(
csm
):
return
csm_properties
(
csm
)[
1
]
def
csm_indices
(
csm
):
return
csm_properties
(
csm
)[
1
]
def
csm_indptr
(
csm
):
return
csm_properties
(
csm
)[
2
]
def
csm_indptr
(
csm
):
return
csm_properties
(
csm
)[
2
]
def
csm_shape
(
csm
):
return
csm_properties
(
csm
)[
3
]
def
csm_shape
(
csm
):
return
csm_properties
(
csm
)[
3
]
class
CSM
(
gof
.
Op
):
class
CSM
(
gof
.
Op
):
"""Construct a CSC or CSR matrix from the internal representation """
"""Construct a CSC or CSR matrix from the internal representation """
view_map
=
{
0
:[
0
]}
#should view the other inputs too, but viewing multiple inputs is not
view_map
=
{
0
:[
0
]}
#should view the other inputs too, but viewing multiple inputs is not
...
@@ -536,6 +563,7 @@ class CSM(gof.Op):
...
@@ -536,6 +563,7 @@ class CSM(gof.Op):
CSC
=
CSM
(
'csc'
)
CSC
=
CSM
(
'csc'
)
CSR
=
CSM
(
'csr'
)
CSR
=
CSM
(
'csr'
)
class
CSMGrad
(
gof
.
op
.
Op
):
class
CSMGrad
(
gof
.
op
.
Op
):
def
__init__
(
self
,
kmap
=
None
):
def
__init__
(
self
,
kmap
=
None
):
self
.
kmap
=
kmap
self
.
kmap
=
kmap
...
@@ -563,6 +591,7 @@ class CSMGrad(gof.op.Op):
...
@@ -563,6 +591,7 @@ class CSMGrad(gof.op.Op):
g_data
[
0
]
=
grad
g_data
[
0
]
=
grad
csm_grad
=
CSMGrad
csm_grad
=
CSMGrad
@gof.local_optimizer
([
csm_properties
])
@gof.local_optimizer
([
csm_properties
])
def
skip_pack_csc01
(
node
):
def
skip_pack_csc01
(
node
):
"""if we find csm_properties(CSM(*args)), then we can replace that with the *args
"""if we find csm_properties(CSM(*args)), then we can replace that with the *args
...
@@ -580,7 +609,6 @@ def skip_pack_csc01(node):
...
@@ -580,7 +609,6 @@ def skip_pack_csc01(node):
register_specialize
(
skip_pack_csc01
)
register_specialize
(
skip_pack_csc01
)
#
#
# Conversion
# Conversion
#
#
...
@@ -617,6 +645,7 @@ class DenseFromSparse(gof.op.Op):
...
@@ -617,6 +645,7 @@ class DenseFromSparse(gof.op.Op):
return
[
ishape
]
return
[
ishape
]
dense_from_sparse
=
DenseFromSparse
()
dense_from_sparse
=
DenseFromSparse
()
class
SparseFromDense
(
gof
.
op
.
Op
):
class
SparseFromDense
(
gof
.
op
.
Op
):
def
__init__
(
self
,
format
):
def
__init__
(
self
,
format
):
self
.
format
=
format
self
.
format
=
format
...
@@ -817,6 +846,7 @@ class Transpose(gof.op.Op):
...
@@ -817,6 +846,7 @@ class Transpose(gof.op.Op):
return
transpose
(
gz
),
return
transpose
(
gz
),
transpose
=
Transpose
()
transpose
=
Transpose
()
class
Neg
(
gof
.
op
.
Op
):
class
Neg
(
gof
.
op
.
Op
):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
return
(
type
(
self
)
==
type
(
other
))
...
@@ -833,6 +863,7 @@ class Neg(gof.op.Op):
...
@@ -833,6 +863,7 @@ class Neg(gof.op.Op):
return
-
gz
,
return
-
gz
,
neg
=
Neg
()
neg
=
Neg
()
class
AddSS
(
gof
.
op
.
Op
):
class
AddSS
(
gof
.
op
.
Op
):
'''Add two sparse matrices '''
'''Add two sparse matrices '''
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -858,6 +889,8 @@ class AddSS(gof.op.Op):
...
@@ -858,6 +889,8 @@ class AddSS(gof.op.Op):
assert
_is_sparse_variable
(
gz
)
assert
_is_sparse_variable
(
gz
)
return
gz
,
gz
return
gz
,
gz
add_s_s
=
AddSS
()
add_s_s
=
AddSS
()
class
AddSD
(
gof
.
op
.
Op
):
class
AddSD
(
gof
.
op
.
Op
):
''' Add a sparse and a dense matrix '''
''' Add a sparse and a dense matrix '''
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -885,6 +918,8 @@ class AddSD(gof.op.Op):
...
@@ -885,6 +918,8 @@ class AddSD(gof.op.Op):
assert
_is_dense_variable
(
gz
)
assert
_is_dense_variable
(
gz
)
return
sp_ones_like
(
x
)
*
gz
,
gz
return
sp_ones_like
(
x
)
*
gz
,
gz
add_s_d
=
AddSD
()
add_s_d
=
AddSD
()
def
add
(
x
,
y
):
def
add
(
x
,
y
):
"""
"""
Add two matrices, at least one of which is sparse.
Add two matrices, at least one of which is sparse.
...
@@ -900,11 +935,12 @@ def add(x,y):
...
@@ -900,11 +935,12 @@ def add(x,y):
elif
x_is_sparse_variable
and
not
y_is_sparse_variable
:
return
add_s_d
(
x
,
y
)
elif
x_is_sparse_variable
and
not
y_is_sparse_variable
:
return
add_s_d
(
x
,
y
)
elif
y_is_sparse_variable
and
not
x_is_sparse_variable
:
return
add_s_d
(
y
,
x
)
elif
y_is_sparse_variable
and
not
x_is_sparse_variable
:
return
add_s_d
(
y
,
x
)
else
:
raise
NotImplementedError
()
else
:
raise
NotImplementedError
()
def
sub
(
x
,
y
):
def
sub
(
x
,
y
):
return
x
+
(
-
y
)
return
x
+
(
-
y
)
class
MulSS
(
gof
.
op
.
Op
):
class
MulSS
(
gof
.
op
.
Op
):
''' Elementwise multiply a sparse and a sparse '''
''' Elementwise multiply a sparse and a sparse '''
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -928,6 +964,8 @@ class MulSS(gof.op.Op):
...
@@ -928,6 +964,8 @@ class MulSS(gof.op.Op):
def
grad
(
self
,
(
x
,
y
),
(
gz
,)):
def
grad
(
self
,
(
x
,
y
),
(
gz
,)):
return
y
*
gz
,
x
*
gz
return
y
*
gz
,
x
*
gz
mul_s_s
=
MulSS
()
mul_s_s
=
MulSS
()
class
MulSD
(
gof
.
op
.
Op
):
class
MulSD
(
gof
.
op
.
Op
):
''' Elementwise multiply a sparse and a ndarray '''
''' Elementwise multiply a sparse and a ndarray '''
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -995,6 +1033,8 @@ class MulSD(gof.op.Op):
...
@@ -995,6 +1033,8 @@ class MulSD(gof.op.Op):
assert
_is_sparse_variable
(
gz
)
assert
_is_sparse_variable
(
gz
)
return
y
*
gz
,
x
*
gz
return
y
*
gz
,
x
*
gz
mul_s_d
=
MulSD
()
mul_s_d
=
MulSD
()
def
mul
(
x
,
y
):
def
mul
(
x
,
y
):
"""
"""
Multiply (elementwise) two matrices, at least one of which is sparse.
Multiply (elementwise) two matrices, at least one of which is sparse.
...
@@ -1011,6 +1051,7 @@ def mul(x,y):
...
@@ -1011,6 +1051,7 @@ def mul(x,y):
elif
y_is_sparse_variable
and
not
x_is_sparse_variable
:
return
mul_s_d
(
y
,
x
)
elif
y_is_sparse_variable
and
not
x_is_sparse_variable
:
return
mul_s_d
(
y
,
x
)
else
:
raise
NotImplementedError
()
else
:
raise
NotImplementedError
()
###############
###############
#
#
# StructuredDot
# StructuredDot
...
@@ -1073,9 +1114,9 @@ class StructuredDot(gof.Op):
...
@@ -1073,9 +1114,9 @@ class StructuredDot(gof.Op):
# ga = g_out x b.T
# ga = g_out x b.T
# gb = a.T x g_out
# gb = a.T x g_out
return
[
structured_dot_grad
(
a
,
b
,
g_out
),
structured_dot
(
a
.
T
,
g_out
)]
return
[
structured_dot_grad
(
a
,
b
,
g_out
),
structured_dot
(
a
.
T
,
g_out
)]
_structured_dot
=
StructuredDot
()
_structured_dot
=
StructuredDot
()
def
structured_dot
(
x
,
y
):
def
structured_dot
(
x
,
y
):
"""
"""
@todo: Maybe the triple-transposition formulation (when x is dense)
@todo: Maybe the triple-transposition formulation (when x is dense)
...
@@ -1096,6 +1137,7 @@ def structured_dot(x, y):
...
@@ -1096,6 +1137,7 @@ def structured_dot(x, y):
assert
y_is_sparse_variable
assert
y_is_sparse_variable
return
_structured_dot
(
y
.
T
,
x
.
T
)
.
T
return
_structured_dot
(
y
.
T
,
x
.
T
)
.
T
class
StructuredDotCSC
(
gof
.
Op
):
class
StructuredDotCSC
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
return
(
type
(
self
)
==
type
(
other
))
...
@@ -1262,6 +1304,7 @@ class StructuredDotCSC(gof.Op):
...
@@ -1262,6 +1304,7 @@ class StructuredDotCSC(gof.Op):
return
(
2
,)
return
(
2
,)
sd_csc
=
StructuredDotCSC
()
sd_csc
=
StructuredDotCSC
()
class
StructuredDotCSR
(
gof
.
Op
):
class
StructuredDotCSR
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
return
(
type
(
self
)
==
type
(
other
))
...
@@ -1394,6 +1437,7 @@ class StructuredDotCSR(gof.Op):
...
@@ -1394,6 +1437,7 @@ class StructuredDotCSR(gof.Op):
return
(
1
,)
return
(
1
,)
sd_csr
=
StructuredDotCSR
()
sd_csr
=
StructuredDotCSR
()
# register a specialization to replace StructuredDot -> StructuredDotCSx
# register a specialization to replace StructuredDot -> StructuredDotCSx
@gof.local_optimizer
([
_structured_dot
])
@gof.local_optimizer
([
_structured_dot
])
def
local_structured_dot
(
node
):
def
local_structured_dot
(
node
):
...
@@ -1414,6 +1458,7 @@ def local_structured_dot(node):
...
@@ -1414,6 +1458,7 @@ def local_structured_dot(node):
# involved. dimension mismatches are hard to detect sensibly.
# involved. dimension mismatches are hard to detect sensibly.
#register_specialize(local_structured_dot)
#register_specialize(local_structured_dot)
def
structured_dot_grad
(
sparse_A
,
dense_B
,
ga
):
def
structured_dot_grad
(
sparse_A
,
dense_B
,
ga
):
if
sparse_A
.
type
.
format
in
(
'csc'
,
'csr'
):
if
sparse_A
.
type
.
format
in
(
'csc'
,
'csr'
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论