Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ada92aea
提交
ada92aea
authored
7月 16, 2009
作者:
bergstra@tikuanyin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ModuleCache working
上级
1f2d68ea
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
232 行增加
和
69 行删除
+232
-69
cc.py
theano/gof/cc.py
+30
-34
cmodule.py
theano/gof/cmodule.py
+134
-14
op.py
theano/gof/op.py
+1
-1
basic.py
theano/scalar/basic.py
+2
-2
basic.py
theano/tensor/basic.py
+10
-6
elemwise.py
theano/tensor/elemwise.py
+24
-6
nnet.py
theano/tensor/nnet.py
+31
-6
没有找到文件。
theano/gof/cc.py
浏览文件 @
ada92aea
...
...
@@ -29,6 +29,7 @@ def info(*args):
sys
.
stderr
.
write
(
'INFO:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
debug
(
*
args
):
sys
.
stderr
.
write
(
'DEBUG:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
_logger
.
debug
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
warning
(
*
args
):
sys
.
stderr
.
write
(
'WARNING:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
...
...
@@ -39,19 +40,14 @@ def error(*args):
from
.callcache
import
CallCache
_timers
=
{}
_module_cache
=
None
def
get_module_cache
():
global
_module_cache
if
_module_cache
is
None
:
_module_cache
=
CallCache
()
#TODO: put a filename here for persistence
return
_module_cache
return
cmodule
.
get_module_cache
(
get_compiledir
())
_persistent_module_cache
=
None
def
get_persistent_module_cache
():
global
_persistent_module_cache
if
_persistent_module_cache
is
None
:
_persistent_module_cache
=
CallCache
(
)
#TODO: put a filename here for persistence
_persistent_module_cache
=
CallCache
(
os
.
path
.
join
(
get_compiledir
(),
'persistent_cache'
))
return
_persistent_module_cache
class
CodeBlock
:
...
...
@@ -746,7 +742,29 @@ class CLinker(link.Linker):
rval
=
tuple
(
rval
)
return
rval
def
compile_cmodule
(
self
):
def
compile_cmodule
(
self
,
location
=
None
):
"""
This method is a callback for `ModuleCache.module_from_key`
"""
location
=
get_compiledir
()
if
location
is
None
else
location
mod
=
self
.
build_dynamic_module
()
get_lock
()
try
:
debug
(
"LOCATION"
,
location
)
module
=
self
.
module_compile_str
(
module_name
=
mod
.
name
,
src_code
=
mod
.
code
(),
location
=
location
,
include_dirs
=
[],
libs
=
self
.
libraries
(),
preargs
=
self
.
compile_args
())
finally
:
release_lock
()
return
module
def
build_dynamic_module
(
self
):
"""Generate the code for this module, compile it, return the imported dynamic module.
"""
self
.
code_gen
()
...
...
@@ -755,18 +773,7 @@ class CLinker(link.Linker):
cthunk
=
object
()
# dummy so weave can get the type
mod
=
cmodule
.
DynamicModule
(
module_name
)
if
0
:
# Eliminate duplicate inputs and outputs from the storage that we will pass to instantiate
out_storage
=
[
x
for
i
,
x
in
enumerate
(
out_storage
)
if
(
i
+
len
(
in_storage
))
not
in
self
.
dupidx
]
in_storage
=
[
x
for
i
,
x
in
enumerate
(
in_storage
)
if
i
not
in
self
.
dupidx
]
argnames
=
[
"i
%
i"
%
i
for
i
in
xrange
(
len
(
in_storage
))]
\
+
[
"o
%
i"
%
i
for
i
in
xrange
(
len
(
out_storage
))]
\
+
[
"orph
%
i"
%
i
for
i
in
xrange
(
len
(
self
.
orphans
))]
# The code of instantiate
#code = self.instantiate_code(1+len(argnames)) #the 1 is for error_storage
code
=
self
.
instantiate_code
(
1
+
len
(
self
.
args
))
#the 1 is for error_storage
instantiate
=
cmodule
.
ExtFunction
(
'instantiate'
,
code
,
method
=
cmodule
.
METH_VARARGS
)
#['error_storage'] + argnames,
...
...
@@ -799,19 +806,7 @@ class CLinker(link.Linker):
for
header
in
self
.
headers
():
mod
.
add_include
(
header
)
get_lock
()
try
:
module
=
self
.
module_compile_str
(
module_name
=
mod
.
name
,
src_code
=
mod
.
code
(),
location
=
get_compiledir
(),
include_dirs
=
[],
libs
=
self
.
libraries
(),
preargs
=
self
.
compile_args
())
finally
:
release_lock
()
return
module
return
mod
def
cthunk_factory
(
self
,
error_storage
,
in_storage
,
out_storage
):
...
...
@@ -831,9 +826,10 @@ class CLinker(link.Linker):
except
KeyError
:
key
=
None
if
key
is
None
:
#if we can't get a key, then forget the cache mechanism
module
=
self
.
compile_cmodule
()
else
:
module
=
get_module_cache
()
.
call
(
self
.
compile_cmodule
,
key
=
key
)
module
=
get_module_cache
()
.
module_from_key
(
key
=
key
,
fn
=
self
.
compile_cmodule
)
vars
=
self
.
inputs
+
self
.
outputs
+
self
.
orphans
# List of indices that should be ignored when passing the arguments
...
...
theano/gof/cmodule.py
浏览文件 @
ada92aea
"""Generate and compile C modules for Python
"""
import
os
,
tempfile
,
StringIO
,
sys
,
logging
,
subprocess
import
os
,
tempfile
,
StringIO
,
sys
,
logging
,
subprocess
,
cPickle
,
atexit
_logger
=
logging
.
getLogger
(
"theano.gof.cmodule"
)
def
warning
(
*
args
):
sys
.
stderr
.
write
(
'WARNING:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
#
sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
warning
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
info
(
*
args
):
sys
.
stderr
.
write
(
'INFO:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
#
sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
debug
(
*
args
):
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
...
...
@@ -115,10 +115,138 @@ class DynamicModule(object):
#TODO: add_type
def
dlimport
(
fullpath
,
suffix
=
None
):
"""Dynamically load a .so, .dll, or .py file
:type fullpath: string
:param fullpath: a fully-qualified path do a compiled python module
:param suffix: a suffix to strip from the end of fullpath to get the import name
:type suffix: string
:returns: the dynamically loaded module (from __import__)
"""
if
suffix
is
None
:
if
fullpath
.
endswith
(
'.so'
):
suffix
=
'.so'
elif
fullpath
.
endswith
(
'.dll'
):
suffix
=
'.dll'
elif
fullpath
.
endswith
(
'.py'
):
suffix
=
'.py'
else
:
suffix
=
''
rval
=
None
if
fullpath
.
endswith
(
suffix
):
module_name
=
'.'
.
join
(
fullpath
.
split
(
os
.
path
.
sep
)[
-
2
:])[:
-
len
(
suffix
)]
else
:
raise
ValueError
(
'path has wrong suffix'
,
(
fullpath
,
suffix
))
workdir
=
fullpath
[:
-
len
(
module_name
)
-
1
-
len
(
suffix
)]
#debug("WORKDIR", workdir)
#debug("module_name", module_name)
pathcopy
=
list
(
sys
.
path
)
sys
.
path
=
[
workdir
]
try
:
rval
=
__import__
(
module_name
,
{},
{},
[
module_name
])
if
not
rval
:
error
(
'__import__ failed'
,
fullpath
)
finally
:
sys
.
path
=
pathcopy
assert
fullpath
.
startswith
(
rval
.
__file__
)
return
rval
class
ModuleCache
(
object
):
def
__init__
(
self
,
dirname
,
force_fresh
=
False
):
self
.
dirname
=
dirname
self
.
module_from_name
=
{}
self
.
name_from_key_filename
=
os
.
path
.
join
(
self
.
dirname
,
'module_cache.pkl'
)
self
.
name_from_key
=
{}
self
.
stats
=
[
0
,
0
,
0
]
if
not
force_fresh
:
try
:
f
=
file
(
self
.
name_from_key_filename
,
'r'
)
self
.
name_from_key
=
cPickle
.
load
(
f
)
debug
(
'ModuleCache loaded'
,
len
(
self
.
name_from_key
))
f
.
close
()
except
(
IOError
,
EOFError
):
debug
(
'cache load failed. Using fresh cache'
)
pass
def
persist
(
self
):
f
=
file
(
self
.
name_from_key_filename
,
'w'
)
cPickle
.
dump
(
self
.
name_from_key
,
f
)
f
.
close
()
def
module_from_key
(
self
,
key
,
fn
=
None
):
rval
=
None
if
key
in
self
.
name_from_key
:
# we have seen this key either in this process or previously
#debug('OLD KEY HASH', hash(key), hash(key[1][0]), key[1][0])
name
=
self
.
name_from_key
[
key
]
if
name
not
in
self
.
module_from_name
:
#debug('loading name', name)
self
.
module_from_name
[
name
]
=
dlimport
(
name
)
self
.
stats
[
1
]
+=
1
else
:
self
.
stats
[
0
]
+=
1
rval
=
self
.
module_from_name
[
name
]
else
:
# we have never seen this key before
location
=
tempfile
.
mkdtemp
(
dir
=
self
.
dirname
)
#debug("LOCATION*", location)
try
:
module
=
fn
(
location
=
location
)
# WILL FAIL FOR BAD C CODE
finally
:
# >>TODO: erase location
pass
debug
(
'NEW KEY HASH'
,
hash
(
key
),
hash
(
key
[
1
][
0
]),
key
[
1
][
0
])
for
k
,
n
in
self
.
name_from_key
.
iteritems
():
if
k
==
key
:
debug
(
"HASH OF RELOAD IS DIFFERENT"
,
hash
(
k
),
hash
(
key
))
print
''
print
hash
(
k
[
0
])
print
hash
(
key
[
0
])
print
''
print
"OLD"
,
print
hash
(
k
[
1
][
0
])
print
k
[
1
][
0
]
.
rehash
()
print
""
print
"NEW"
,
hash
(
key
[
1
][
0
]),
key
[
1
][
0
]
.
rehash
()
print
''
print
hash
(
k
[
1
][
1
])
print
hash
(
key
[
1
][
1
])
assert
k
!=
key
name
=
module
.
__file__
#debug("LOCATION**", location)
#debug("NAME**", name)
assert
name
.
startswith
(
location
)
assert
name
not
in
self
.
module_from_name
assert
key
not
in
self
.
name_from_key
self
.
name_from_key
[
key
]
=
name
self
.
module_from_name
[
name
]
=
module
self
.
stats
[
2
]
+=
1
rval
=
module
#debug('stats', self.stats, sum(self.stats))
return
rval
_module_cache
=
None
def
get_module_cache
(
dirname
):
global
_module_cache
if
_module_cache
is
None
:
_module_cache
=
ModuleCache
(
dirname
,
force_fresh
=
False
)
atexit
.
register
(
_module_cache
.
persist
)
return
_module_cache
def
gcc_module_compile_str
(
module_name
,
src_code
,
location
=
None
,
include_dirs
=
[],
lib_dirs
=
[],
libs
=
[],
preargs
=
[],
tmpdir
=
None
):
#TODO: don't to the dlimport in this function
preargs
=
[]
if
preargs
is
None
else
list
(
preargs
)
preargs
.
append
(
'-fPIC'
)
no_opt
=
False
...
...
@@ -127,7 +255,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
include_dirs
=
[
'/usr/include/python2.6'
]
+
include_dirs
libs
=
[
'python2.6'
]
+
libs
workdir
=
tempfile
.
mkdtemp
(
dir
=
location
)
workdir
=
location
cppfilename
=
os
.
path
.
join
(
workdir
,
'mod.cpp'
)
cppfile
=
file
(
cppfilename
,
'w'
)
...
...
@@ -157,19 +285,12 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
status
=
p
.
wait
()
if
status
:
warning
(
'g++ return status'
,
status
)
error
(
'g++ return status'
,
status
)
else
:
#touch the __init__ file
file
(
os
.
path
.
join
(
workdir
,
"__init__.py"
),
'w'
)
.
close
()
#load the module
sys
.
path
.
insert
(
0
,
workdir
)
try
:
rval
=
__import__
(
module_name
,
{},
{},
[
module_name
])
if
not
rval
:
debug
(
'__import__ failed'
)
finally
:
del
sys
.
path
[
0
]
rval
=
dlimport
(
lib_filename
)
finally
:
warning
(
"TODO: cleanup"
)
...
...
@@ -228,7 +349,6 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
file
(
os
.
path
.
join
(
workdir
,
"__init__.py"
),
'w'
)
.
close
()
#load the module
pathcopy
=
list
(
sys
.
path
)
sys
.
path
.
insert
(
0
,
workdir
)
try
:
rval
=
__import__
(
module_name
,
{},
{},
[
module_name
])
...
...
theano/gof/op.py
浏览文件 @
ada92aea
...
...
@@ -196,7 +196,7 @@ class PureOp(object):
variable or an instance variable.
"""
#############
# make_node #
#############
...
...
theano/scalar/basic.py
浏览文件 @
ada92aea
...
...
@@ -59,7 +59,7 @@ class Scalar(Type):
return
type
(
self
)
==
type
(
other
)
and
other
.
dtype
==
self
.
dtype
def
__hash__
(
self
):
return
hash
(
self
.
dtype
)
return
hash
(
'theano.scalar.Scalar'
)
^
hash
(
self
.
dtype
)
def
dtype_specs
(
self
):
try
:
...
...
@@ -348,7 +348,7 @@ class ScalarOp(Op):
return
test
def
__hash__
(
self
):
return
hash
(
getattr
(
self
,
'output_types_preference'
,
0
))
return
hash
(
type
(
self
)
.
__name__
)
^
hash
(
getattr
(
self
,
'output_types_preference'
,
0
))
def
__str__
(
self
):
if
hasattr
(
self
,
'name'
)
and
self
.
name
:
...
...
theano/tensor/basic.py
浏览文件 @
ada92aea
...
...
@@ -41,6 +41,10 @@ def check_equal_numpy(x, y):
compile
.
register_checker
(
check_equal_numpy
)
def
hashtype
(
self
):
t
=
type
(
self
)
return
hash
(
t
.
__name__
)
^
hash
(
t
.
__module__
)
elemwise
.
hashtype
=
hashtype
__oplist_constructor_list
=
[]
...
...
@@ -305,7 +309,7 @@ class TensorType(Type):
def
__hash__
(
self
):
"""Hash equal for same kinds of TensorType"""
return
hash
(
type
(
self
)
)
^
hash
(
self
.
dtype
)
^
hash
(
self
.
broadcastable
)
return
hash
type
(
self
)
^
hash
(
self
.
dtype
)
^
hash
(
self
.
broadcastable
)
ndim
=
property
(
lambda
self
:
len
(
self
.
broadcastable
),
doc
=
"number of dimensions"
)
"""Number of dimensions
...
...
@@ -732,7 +736,7 @@ class TensorConstantSignature(tuple):
return
(
x
==
a
)
and
(
b
.
shape
==
y
.
shape
)
and
(
numpy
.
all
(
b
==
y
))
def
__hash__
(
self
):
a
,
b
=
self
return
hash
(
type
(
self
)
)
^
hash
(
a
)
^
hash
(
b
.
shape
)
return
hash
type
(
self
)
^
hash
(
a
)
^
hash
(
b
.
shape
)
class
TensorConstant
(
Constant
,
_tensor_py_operators
):
"""Subclass to add the tensor operators to the basic `Constant` class.
...
...
@@ -1607,7 +1611,7 @@ class SetSubtensor(Op):
if
isinstance
(
entry
,
slice
)
else
entry
for
entry
in
self
.
idx_list
)
return
hash
(
type
(
self
)
)
^
hash
(
idx_list
)
^
hash
(
self
.
inplace
)
return
hash
type
(
self
)
^
hash
(
idx_list
)
^
hash
(
self
.
inplace
)
def
__str__
(
self
):
indices
=
[]
...
...
@@ -2125,7 +2129,7 @@ class Flatten(Op):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
outdim
==
other
.
outdim
def
__hash__
(
self
):
return
hash
(
type
(
self
)
)
^
hash
(
self
.
outdim
)
return
hash
type
(
self
)
^
hash
(
self
.
outdim
)
def
make_node
(
self
,
x
):
t_x
=
as_tensor_variable
(
x
)
if
self
.
outdim
<
1
or
(
x
.
ndim
and
self
.
outdim
>
x
.
ndim
):
...
...
@@ -2277,7 +2281,7 @@ class TensorDotGrad(Op):
return
type
(
self
)
==
type
(
other
)
and
self
.
axes
==
other
.
axes
def
__hash__
(
self
):
return
hash
(
type
(
self
)
)
^
hash
(
self
.
axes
)
^
89234
return
hash
type
(
self
)
^
hash
(
self
.
axes
)
^
89234
def
make_node
(
self
,
x
,
y
,
gz
):
assert
isinstance
(
x
,
Variable
)
...
...
@@ -2324,7 +2328,7 @@ class TensorDot(Op):
return
type
(
self
)
==
type
(
other
)
and
self
.
axes
==
other
.
axes
def
__hash__
(
self
):
return
hash
(
type
(
self
)
)
^
hash
(
self
.
axes
)
^
89234
return
hash
type
(
self
)
^
hash
(
self
.
axes
)
^
89234
def
make_node
(
self
,
x
,
y
):
...
...
theano/tensor/elemwise.py
浏览文件 @
ada92aea
...
...
@@ -123,8 +123,15 @@ class DimShuffle(Op):
if
self
.
inplace
:
self
.
view_map
=
{
0
:
[
0
]}
self
.
_hashval
=
hash
(
type
(
self
))
^
hash
(
self
.
inplace
)
\
^
hash
(
self
.
new_order
)
^
hash
(
self
.
input_broadcastable
)
self
.
_rehash
()
def
__getstate__
(
self
):
d
=
dict
(
self
.
__dict__
)
del
d
[
'_hashval'
]
return
d
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
self
.
_rehash
()
def
make_node
(
self
,
input
):
ib
=
tuple
(
input
.
type
.
broadcastable
)
...
...
@@ -148,6 +155,10 @@ class DimShuffle(Op):
and
self
.
new_order
==
other
.
new_order
\
and
self
.
input_broadcastable
==
other
.
input_broadcastable
def
_rehash
(
self
):
self
.
_hashval
=
hash
(
type
(
self
)
.
__name__
)
^
hash
(
type
(
self
)
.
__module__
)
^
hash
(
self
.
inplace
)
\
^
hash
(
self
.
new_order
)
^
hash
(
self
.
input_broadcastable
)
def
__hash__
(
self
):
return
self
.
_hashval
...
...
@@ -353,15 +364,13 @@ class Elemwise(Op):
self
.
ufunc
=
None
#precompute the hash of this node
items
=
self
.
inplace_pattern
.
items
()
items
.
sort
()
tuple_items
=
tuple
([
k
for
k
,
v
in
items
]
+
[(
tuple
(
v
)
if
isinstance
(
v
,
(
tuple
,
list
))
else
v
)
for
k
,
v
in
items
])
self
.
_hashval
=
hash
(
self
.
scalar_op
)
^
hash
(
tuple_items
)
self
.
_rehash
()
def
__getstate__
(
self
):
d
=
copy
(
self
.
__dict__
)
d
.
pop
(
'ufunc'
)
d
.
pop
(
'__epydoc_asRoutine'
,
None
)
d
.
pop
(
'_hashval'
)
return
d
def
__setstate__
(
self
,
d
):
...
...
@@ -370,6 +379,7 @@ class Elemwise(Op):
self
.
ufunc
=
numpy
.
frompyfunc
(
self
.
scalar_op
.
impl
,
self
.
scalar_op
.
nin
,
self
.
scalar_op
.
nout
)
else
:
self
.
ufunc
=
None
self
.
_rehash
()
def
make_node
(
self
,
*
inputs
):
"""
...
...
@@ -429,6 +439,14 @@ class Elemwise(Op):
return
rval
return
False
def
_rehash
(
self
):
items
=
self
.
inplace_pattern
.
items
()
items
.
sort
()
tuple_items
=
tuple
([
k
for
k
,
v
in
items
]
+
[(
tuple
(
v
)
if
isinstance
(
v
,
(
tuple
,
list
))
else
v
)
for
k
,
v
in
items
])
h
=
hash
(
'Elemwise'
)
^
hash
(
self
.
scalar_op
)
^
hash
(
tuple_items
)
assert
h
==
getattr
(
self
,
'_hashval'
,
h
)
self
.
_hashval
=
h
def
__hash__
(
self
):
return
self
.
_hashval
...
...
theano/tensor/nnet.py
浏览文件 @
ada92aea
...
...
@@ -94,6 +94,11 @@ class SoftmaxWithBias(gof.Op):
def
__init__
(
self
,
**
kwargs
):
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
def
make_node
(
self
,
x
,
b
):
x
=
tensor
.
as_tensor_variable
(
x
)
b
=
tensor
.
as_tensor_variable
(
b
)
...
...
@@ -266,8 +271,9 @@ class SoftmaxGrad(gof.Op):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
)
)
return
tensor
.
hashtype
(
self
)
def
make_node
(
self
,
dy
,
sm
,
**
kwargs
):
dy
=
tensor
.
as_tensor_variable
(
dy
)
...
...
@@ -437,6 +443,10 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
nout
=
3
def
__init__
(
self
,
**
kwargs
):
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
def
make_node
(
self
,
x
,
b
,
y_idx
):
x
=
tensor
.
as_tensor_variable
(
x
)
...
...
@@ -608,6 +618,10 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
"""Gradient wrt x of the CrossentropySoftmax1Hot Op"""
def
__init__
(
self
,
**
kwargs
):
gof
.
Op
.
__init__
(
self
,
**
kwargs
)
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
def
make_node
(
self
,
dy
,
sm
,
y_idx
,
**
kwargs
):
dy
=
tensor
.
as_tensor_variable
(
dy
)
sm
=
tensor
.
as_tensor_variable
(
sm
)
...
...
@@ -728,7 +742,7 @@ class CrossentropyCategorical1HotGrad(gof.Op):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
)
)
return
tensor
.
hashtype
(
self
)
def
make_node
(
self
,
g_y
,
coding_dist
,
true_one_of_n
):
return
gof
.
Apply
(
self
,
[
g_y
,
coding_dist
,
true_one_of_n
],
[
coding_dist
.
type
()])
def
perform
(
self
,
node
,
(
g_y
,
coding_dist
,
true_one_of_n
),
(
g_coding_strg
,)):
...
...
@@ -741,10 +755,6 @@ crossentropy_categorical_1hot_grad = CrossentropyCategorical1HotGrad()
class
CrossentropyCategorical1Hot
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
"""Compute the cross entropy between a coding distribution and
a true distribution of the form [0, 0, ... 0, 1, 0, ..., 0]
...
...
@@ -758,6 +768,11 @@ class CrossentropyCategorical1Hot(gof.Op):
Op will probably be optimized away in favour of one with a C implementation.
"""
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
def
make_node
(
self
,
coding_dist
,
true_one_of_n
):
"""
:type coding_dist: dense matrix
...
...
@@ -906,6 +921,11 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
val
=
scalar
.
constant
(
val
)
self
.
val
=
val
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
and
(
self
.
val
==
other
.
val
)
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
^
hash
(
self
.
val
.
value
)
def
make_node
(
self
,
mat
):
#check type of input
if
not
isinstance
(
mat
,
gof
.
Variable
)
or
not
mat
.
type
==
tensor
.
matrix
()
.
type
:
...
...
@@ -938,6 +958,11 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
return
goutput
[:,
1
:]
class
Prepend_scalar_to_each_row
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
tensor
.
hashtype
(
self
)
def
make_node
(
self
,
val
,
mat
):
#check type of input
if
isinstance
(
val
,
float
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论