Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
eb6569b5
提交
eb6569b5
authored
7月 16, 2009
作者:
bergstra@tikuanyin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ModuleCache works without the pkl file now, more robust to various errors
上级
ada92aea
全部展开
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
125 行增加
和
8 行删除
+125
-8
__init__.py
theano/__init__.py
+10
-0
test_inplace_opt_for_value.py
theano/compile/tests/test_inplace_opt_for_value.py
+6
-1
cc.py
theano/gof/cc.py
+31
-7
cmodule.py
theano/gof/cmodule.py
+0
-0
op.py
theano/gof/op.py
+10
-0
test_cc.py
theano/gof/tests/test_cc.py
+5
-0
type.py
theano/gof/type.py
+10
-0
basic.py
theano/sparse/basic.py
+49
-0
blas.py
theano/tensor/blas.py
+4
-0
没有找到文件。
theano/__init__.py
浏览文件 @
eb6569b5
...
@@ -147,3 +147,13 @@ def dot(l, r):
...
@@ -147,3 +147,13 @@ def dot(l, r):
raise
NotImplementedError
(
"Dot failed for the following reaons:"
,
(
e0
,
e1
))
raise
NotImplementedError
(
"Dot failed for the following reaons:"
,
(
e0
,
e1
))
return
rval
return
rval
###
# Set a default logger
#
import
logging
logging_default_handler
=
logging
.
StreamHandler
()
logging
.
getLogger
(
"theano"
)
.
addHandler
(
logging_default_handler
)
logging
.
getLogger
(
"theano"
)
.
setLevel
(
logging
.
WARNING
)
theano/compile/tests/test_inplace_opt_for_value.py
浏览文件 @
eb6569b5
...
@@ -81,6 +81,11 @@ class TanhRnn(Op):
...
@@ -81,6 +81,11 @@ class TanhRnn(Op):
in which z[0] = z0.
in which z[0] = z0.
"""
"""
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
z0
,
A
):
def
make_node
(
self
,
x
,
z0
,
A
):
"""
"""
...
@@ -121,7 +126,7 @@ class TanhRnnGrad(Op):
...
@@ -121,7 +126,7 @@ class TanhRnnGrad(Op):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
,
other
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
def
make_node
(
self
,
A
,
z
,
gz
):
def
make_node
(
self
,
A
,
z
,
gz
):
...
...
theano/gof/cc.py
浏览文件 @
eb6569b5
...
@@ -26,10 +26,10 @@ import cmodule
...
@@ -26,10 +26,10 @@ import cmodule
import
logging
import
logging
_logger
=
logging
.
getLogger
(
"theano.gof.cc"
)
_logger
=
logging
.
getLogger
(
"theano.gof.cc"
)
def
info
(
*
args
):
def
info
(
*
args
):
sys
.
stderr
.
write
(
'INFO:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
#
sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
debug
(
*
args
):
def
debug
(
*
args
):
sys
.
stderr
.
write
(
'DEBUG:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
#
sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
debug
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
debug
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
warning
(
*
args
):
def
warning
(
*
args
):
sys
.
stderr
.
write
(
'WARNING:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
sys
.
stderr
.
write
(
'WARNING:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
...
@@ -367,6 +367,7 @@ class CLinker(link.Linker):
...
@@ -367,6 +367,7 @@ class CLinker(link.Linker):
# The orphans field is listified to ensure a consistent order.
# The orphans field is listified to ensure a consistent order.
self
.
orphans
=
list
(
r
for
r
in
self
.
variables
if
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
)
#list(env.orphans.difference(self.outputs))
self
.
orphans
=
list
(
r
for
r
in
self
.
variables
if
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
)
#list(env.orphans.difference(self.outputs))
self
.
temps
=
list
(
set
(
self
.
variables
)
.
difference
(
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
temps
=
list
(
set
(
self
.
variables
)
.
difference
(
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
consts
=
[]
self
.
node_order
=
env
.
toposort
()
self
.
node_order
=
env
.
toposort
()
def
code_gen
(
self
):
def
code_gen
(
self
):
...
@@ -390,7 +391,7 @@ class CLinker(link.Linker):
...
@@ -390,7 +391,7 @@ class CLinker(link.Linker):
env
=
self
.
env
env
=
self
.
env
consts
=
[]
self
.
consts
=
[]
symbol
=
{}
symbol
=
{}
...
@@ -428,7 +429,7 @@ class CLinker(link.Linker):
...
@@ -428,7 +429,7 @@ class CLinker(link.Linker):
if
isinstance
(
variable
,
graph
.
Constant
):
if
isinstance
(
variable
,
graph
.
Constant
):
try
:
try
:
symbol
[
variable
]
=
"("
+
variable
.
type
.
c_literal
(
variable
.
data
)
+
")"
symbol
[
variable
]
=
"("
+
variable
.
type
.
c_literal
(
variable
.
data
)
+
")"
consts
.
append
(
variable
)
self
.
consts
.
append
(
variable
)
self
.
orphans
.
remove
(
variable
)
self
.
orphans
.
remove
(
variable
)
continue
continue
except
(
utils
.
MethodNotDefined
,
NotImplementedError
):
except
(
utils
.
MethodNotDefined
,
NotImplementedError
):
...
@@ -530,6 +531,11 @@ class CLinker(link.Linker):
...
@@ -530,6 +531,11 @@ class CLinker(link.Linker):
self
.
tasks
=
tasks
self
.
tasks
=
tasks
all
=
self
.
inputs
+
self
.
outputs
+
self
.
orphans
all
=
self
.
inputs
+
self
.
outputs
+
self
.
orphans
if
(
self
.
init_tasks
,
self
.
tasks
)
!=
self
.
get_init_tasks
():
print
>>
sys
.
stderr
,
"init_tasks
\n
"
,
self
.
init_tasks
print
>>
sys
.
stderr
,
self
.
get_init_tasks
()[
0
]
print
>>
sys
.
stderr
,
"tasks
\n
"
,
self
.
tasks
print
>>
sys
.
stderr
,
self
.
get_init_tasks
()[
1
]
assert
(
self
.
init_tasks
,
self
.
tasks
)
==
self
.
get_init_tasks
()
assert
(
self
.
init_tasks
,
self
.
tasks
)
==
self
.
get_init_tasks
()
# List of indices that should be ignored when passing the arguments
# List of indices that should be ignored when passing the arguments
...
@@ -646,6 +652,14 @@ class CLinker(link.Linker):
...
@@ -646,6 +652,14 @@ class CLinker(link.Linker):
tasks
=
[]
tasks
=
[]
id
=
1
id
=
1
for
v
in
self
.
variables
:
for
v
in
self
.
variables
:
if
v
in
self
.
consts
:
continue
if
v
in
self
.
orphans
and
isinstance
(
v
,
graph
.
Constant
):
try
:
v
.
type
.
c_literal
(
v
.
data
)
#constant will be inlined, no need to get
continue
except
(
utils
.
MethodNotDefined
,
NotImplementedError
):
pass
init_tasks
.
append
((
v
,
'init'
,
id
))
init_tasks
.
append
((
v
,
'init'
,
id
))
tasks
.
append
((
v
,
'get'
,
id
+
1
))
tasks
.
append
((
v
,
'get'
,
id
+
1
))
id
+=
2
id
+=
2
...
@@ -687,7 +701,7 @@ class CLinker(link.Linker):
...
@@ -687,7 +701,7 @@ class CLinker(link.Linker):
The signature has the following form:
The signature has the following form:
{{{
{{{
'CLinker.cmodule_key',
'CLinker.cmodule_key',
compilation args, libraries,
op0, (input0.type, input1.type, input0 pos, input1 pos)
op0, (input0.type, input1.type, input0 pos, input1 pos)
op1, (...)
op1, (...)
...
...
...
@@ -717,6 +731,9 @@ class CLinker(link.Linker):
...
@@ -717,6 +731,9 @@ class CLinker(link.Linker):
env_computed_set
=
set
()
env_computed_set
=
set
()
op_pos
=
{}
# Apply -> topological position
op_pos
=
{}
# Apply -> topological position
rval
=
[
'CLinker.cmodule_key'
]
# will be cast to tuple on return
rval
=
[
'CLinker.cmodule_key'
]
# will be cast to tuple on return
rval
.
append
(
tuple
(
self
.
compile_args
()))
rval
.
append
(
tuple
(
self
.
libraries
()))
version
=
[]
# assert that every input to every node is one of'
# assert that every input to every node is one of'
# - an env input
# - an env input
...
@@ -735,12 +752,19 @@ class CLinker(link.Linker):
...
@@ -735,12 +752,19 @@ class CLinker(link.Linker):
return
(
op_pos
[
i
.
owner
],
i
.
owner
.
outputs
.
index
(
i
))
return
(
op_pos
[
i
.
owner
],
i
.
owner
.
outputs
.
index
(
i
))
for
opos
,
o
in
enumerate
(
order
):
for
opos
,
o
in
enumerate
(
order
):
version
.
append
(
o
.
op
.
c_code_cache_version
())
for
i
in
o
.
inputs
:
version
.
append
(
i
.
type
.
c_code_cache_version
())
for
i
in
o
.
outputs
:
version
.
append
(
i
.
type
.
c_code_cache_version
())
rval
.
append
((
o
.
op
,
tuple
((
i
.
type
,
graphpos
(
i
))
for
i
in
o
.
inputs
)))
rval
.
append
((
o
.
op
,
tuple
((
i
.
type
,
graphpos
(
i
))
for
i
in
o
.
inputs
)))
op_pos
[
o
]
=
opos
op_pos
[
o
]
=
opos
env_computed_set
.
update
(
o
.
outputs
)
env_computed_set
.
update
(
o
.
outputs
)
rval
=
tuple
(
rval
)
for
v
in
version
:
return
rval
if
not
v
:
#one of the ops or types here is unversioned
return
((),
tuple
(
rval
))
return
tuple
(
version
),
tuple
(
rval
)
def
compile_cmodule
(
self
,
location
=
None
):
def
compile_cmodule
(
self
,
location
=
None
):
"""
"""
...
...
theano/gof/cmodule.py
浏览文件 @
eb6569b5
差异被折叠。
点击展开。
theano/gof/op.py
浏览文件 @
eb6569b5
...
@@ -162,6 +162,16 @@ class CLinkerOp(object):
...
@@ -162,6 +162,16 @@ class CLinkerOp(object):
raise
utils
.
MethodNotDefined
(
'
%
s.c_support_code'
\
raise
utils
.
MethodNotDefined
(
'
%
s.c_support_code'
\
%
self
.
__class__
.
__name__
)
%
self
.
__class__
.
__name__
)
def
c_code_cache_version
(
self
):
"""Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
"""
return
(
1
,)
class
PureOp
(
object
):
class
PureOp
(
object
):
"""
"""
An :term:`Op` is a type of operation.
An :term:`Op` is a type of operation.
...
...
theano/gof/tests/test_cc.py
浏览文件 @
eb6569b5
...
@@ -57,6 +57,9 @@ class TDouble(Type):
...
@@ -57,6 +57,9 @@ class TDouble(Type):
free(
%(name)
s_bad_thing);
free(
%(name)
s_bad_thing);
"""
%
locals
()
"""
%
locals
()
def
c_code_cache_version
(
self
):
return
()
tdouble
=
TDouble
()
tdouble
=
TDouble
()
def
double
(
name
):
def
double
(
name
):
...
@@ -83,6 +86,8 @@ class MyOp(Op):
...
@@ -83,6 +86,8 @@ class MyOp(Op):
def
perform
(
self
,
node
,
inputs
,
(
out
,
)):
def
perform
(
self
,
node
,
inputs
,
(
out
,
)):
out
[
0
]
=
self
.
impl
(
*
inputs
)
out
[
0
]
=
self
.
impl
(
*
inputs
)
def
c_code_cache_version
(
self
):
return
()
class
Unary
(
MyOp
):
class
Unary
(
MyOp
):
...
...
theano/gof/type.py
浏览文件 @
eb6569b5
...
@@ -210,6 +210,16 @@ class CLinkerType(object):
...
@@ -210,6 +210,16 @@ class CLinkerType(object):
"""
"""
raise
MethodNotDefined
(
"c_support_code"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
MethodNotDefined
(
"c_support_code"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
c_code_cache_version
(
self
):
"""Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
"""
return
(
1
,)
class
PureType
(
object
):
class
PureType
(
object
):
"""Interface specification for variable type instances.
"""Interface specification for variable type instances.
...
...
theano/sparse/basic.py
浏览文件 @
eb6569b5
...
@@ -444,6 +444,10 @@ class DenseFromSparse(gof.op.Op):
...
@@ -444,6 +444,10 @@ class DenseFromSparse(gof.op.Op):
"""
"""
sparse_grad
=
True
sparse_grad
=
True
"""WRITEME"""
"""WRITEME"""
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
as_sparse_variable
(
x
)
x
=
as_sparse_variable
(
x
)
...
@@ -495,6 +499,10 @@ csc_from_dense = SparseFromDense('csc')
...
@@ -495,6 +499,10 @@ csc_from_dense = SparseFromDense('csc')
class
Transpose
(
gof
.
op
.
Op
):
class
Transpose
(
gof
.
op
.
Op
):
format_map
=
{
'csr'
:
'csc'
,
format_map
=
{
'csr'
:
'csc'
,
'csc'
:
'csr'
}
'csc'
:
'csr'
}
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
as_sparse_variable
(
x
)
x
=
as_sparse_variable
(
x
)
return
gof
.
Apply
(
self
,
return
gof
.
Apply
(
self
,
...
@@ -510,6 +518,10 @@ class Transpose(gof.op.Op):
...
@@ -510,6 +518,10 @@ class Transpose(gof.op.Op):
transpose
=
Transpose
()
transpose
=
Transpose
()
class
Neg
(
gof
.
op
.
Op
):
class
Neg
(
gof
.
op
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
as_sparse_variable
(
x
)
x
=
as_sparse_variable
(
x
)
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
...
@@ -523,6 +535,10 @@ neg = Neg()
...
@@ -523,6 +535,10 @@ neg = Neg()
class
AddSS
(
gof
.
op
.
Op
):
class
AddSS
(
gof
.
op
.
Op
):
'''Add two sparse matrices '''
'''Add two sparse matrices '''
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
,
y
=
map
(
as_sparse_variable
,
[
x
,
y
])
x
,
y
=
map
(
as_sparse_variable
,
[
x
,
y
])
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
...
@@ -545,6 +561,10 @@ class AddSS(gof.op.Op):
...
@@ -545,6 +561,10 @@ class AddSS(gof.op.Op):
add_s_s
=
AddSS
()
add_s_s
=
AddSS
()
class
AddSD
(
gof
.
op
.
Op
):
class
AddSD
(
gof
.
op
.
Op
):
''' Add a sparse and a dense matrix '''
''' Add a sparse and a dense matrix '''
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
,
y
=
as_sparse_variable
(
x
),
tensor
.
as_tensor_variable
(
y
)
x
,
y
=
as_sparse_variable
(
x
),
tensor
.
as_tensor_variable
(
y
)
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
...
@@ -586,6 +606,10 @@ def sub(x,y):
...
@@ -586,6 +606,10 @@ def sub(x,y):
class
MulSS
(
gof
.
op
.
Op
):
class
MulSS
(
gof
.
op
.
Op
):
''' Elementwise multiply a sparse and a ndarray '''
''' Elementwise multiply a sparse and a ndarray '''
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
,
y
=
as_sparse_variable
(
x
),
as_sparse_variable
(
y
)
x
,
y
=
as_sparse_variable
(
x
),
as_sparse_variable
(
y
)
if
x
.
type
!=
y
.
type
:
if
x
.
type
!=
y
.
type
:
...
@@ -605,6 +629,10 @@ class MulSS(gof.op.Op):
...
@@ -605,6 +629,10 @@ class MulSS(gof.op.Op):
mul_s_s
=
MulSS
()
mul_s_s
=
MulSS
()
class
MulSD
(
gof
.
op
.
Op
):
class
MulSD
(
gof
.
op
.
Op
):
''' Elementwise multiply a sparse and a ndarray '''
''' Elementwise multiply a sparse and a ndarray '''
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
,
y
=
as_sparse_variable
(
x
),
tensor
.
as_tensor_variable
(
y
)
x
,
y
=
as_sparse_variable
(
x
),
tensor
.
as_tensor_variable
(
y
)
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
...
@@ -686,6 +714,10 @@ class StructuredDot(gof.Op):
...
@@ -686,6 +714,10 @@ class StructuredDot(gof.Op):
The output is presumed to be a dense matrix, and is represented by a TensorType instance.
The output is presumed to be a dense matrix, and is represented by a TensorType instance.
"""
"""
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a
,
b
):
def
make_node
(
self
,
a
,
b
):
if
type
(
a
)
is
not
SparseVariable
and
type
(
a
)
is
not
SparseConstant
:
if
type
(
a
)
is
not
SparseVariable
and
type
(
a
)
is
not
SparseConstant
:
raise
TypeError
(
'First argument must be of type SparseVariable or SparseConstant'
);
raise
TypeError
(
'First argument must be of type SparseVariable or SparseConstant'
);
...
@@ -750,6 +782,10 @@ def structured_dot(x, y):
...
@@ -750,6 +782,10 @@ def structured_dot(x, y):
return
_structured_dot
(
y
.
T
,
x
.
T
)
.
T
return
_structured_dot
(
y
.
T
,
x
.
T
)
.
T
class
StructuredDotCSC
(
gof
.
Op
):
class
StructuredDotCSC
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a_val
,
a_ind
,
a_ptr
,
a_nrows
,
b
):
def
make_node
(
self
,
a_val
,
a_ind
,
a_ptr
,
a_nrows
,
b
):
dtype_out
=
scalar
.
upcast
(
a_val
.
type
.
dtype
,
b
.
type
.
dtype
)
dtype_out
=
scalar
.
upcast
(
a_val
.
type
.
dtype
,
b
.
type
.
dtype
)
r
=
gof
.
Apply
(
self
,
[
a_val
,
a_ind
,
a_ptr
,
a_nrows
,
b
],
r
=
gof
.
Apply
(
self
,
[
a_val
,
a_ind
,
a_ptr
,
a_nrows
,
b
],
...
@@ -900,6 +936,10 @@ class StructuredDotCSC(gof.Op):
...
@@ -900,6 +936,10 @@ class StructuredDotCSC(gof.Op):
sd_csc
=
StructuredDotCSC
()
sd_csc
=
StructuredDotCSC
()
class
StructuredDotCSR
(
gof
.
Op
):
class
StructuredDotCSR
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a_val
,
a_ind
,
a_ptr
,
b
):
def
make_node
(
self
,
a_val
,
a_ind
,
a_ptr
,
b
):
self
.
dtype_out
=
scalar
.
upcast
(
a_val
.
type
.
dtype
,
b
.
type
.
dtype
)
self
.
dtype_out
=
scalar
.
upcast
(
a_val
.
type
.
dtype
,
b
.
type
.
dtype
)
r
=
gof
.
Apply
(
self
,
[
a_val
,
a_ind
,
a_ptr
,
b
],
r
=
gof
.
Apply
(
self
,
[
a_val
,
a_ind
,
a_ptr
,
b
],
...
@@ -1055,6 +1095,10 @@ def structured_dot_grad(sparse_A, dense_B, ga):
...
@@ -1055,6 +1095,10 @@ def structured_dot_grad(sparse_A, dense_B, ga):
class
StructuredDotGradCSC
(
gof
.
Op
):
class
StructuredDotGradCSC
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
return
gof
.
Apply
(
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
return
gof
.
Apply
(
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
[
tensor
.
tensor
(
g_ab
.
dtype
,
(
False
,))])
[
tensor
.
tensor
(
g_ab
.
dtype
,
(
False
,))])
...
@@ -1155,6 +1199,10 @@ sdg_csc = StructuredDotGradCSC()
...
@@ -1155,6 +1199,10 @@ sdg_csc = StructuredDotGradCSC()
class
StructuredDotGradCSR
(
gof
.
Op
):
class
StructuredDotGradCSR
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
return
gof
.
Apply
(
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
[
tensor
.
tensor
(
b
.
dtype
,
(
False
,))])
return
gof
.
Apply
(
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
[
tensor
.
tensor
(
b
.
dtype
,
(
False
,))])
...
@@ -1256,3 +1304,4 @@ class StructuredDotGradCSR(gof.Op):
...
@@ -1256,3 +1304,4 @@ class StructuredDotGradCSR(gof.Op):
"""
%
dict
(
locals
(),
**
sub
)
"""
%
dict
(
locals
(),
**
sub
)
sdg_csr
=
StructuredDotGradCSR
()
sdg_csr
=
StructuredDotGradCSR
()
theano/tensor/blas.py
浏览文件 @
eb6569b5
...
@@ -49,6 +49,10 @@ class GemmRelated(Op):
...
@@ -49,6 +49,10 @@ class GemmRelated(Op):
This class provides a kind of templated gemm Op.
This class provides a kind of templated gemm Op.
"""
"""
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
c_support_code
(
self
):
def
c_support_code
(
self
):
#return cblas_header_text()
#return cblas_header_text()
mod_str
=
"""
mod_str
=
"""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论