Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
460f6b78
提交
460f6b78
authored
2月 21, 2008
作者:
bergstrj@iro.umontreal.ca
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bugs fixed toward replacing NumpyR
上级
45f56231
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
279 行增加
和
57 行删除
+279
-57
core.py
core.py
+272
-53
lib.py
gof/lib.py
+5
-2
grad.py
grad.py
+2
-2
没有找到文件。
core.py
浏览文件 @
460f6b78
...
@@ -39,16 +39,18 @@ def as_string(*rs):
...
@@ -39,16 +39,18 @@ def as_string(*rs):
def
print_graph
(
*
rs
):
def
print_graph
(
*
rs
):
print
as_string
(
*
rs
)
print
as_string
(
*
rs
)
#useful mostly for unit tests
def
_approx_eq
(
a
,
b
,
eps
=
1.0e-9
):
def
_approx_eq
(
a
,
b
,
eps
=
1.0e-9
):
a
=
numpy
.
asarray
(
a
)
a
=
numpy
.
asarray
(
a
)
b
=
numpy
.
asarray
(
b
)
b
=
numpy
.
asarray
(
b
)
if
a
.
shape
!=
b
.
shape
:
if
a
.
shape
!=
b
.
shape
:
return
False
return
False
d
=
abs
(
a
-
b
)
return
numpy
.
max
(
numpy
.
abs
(
a
-
b
))
<
eps
return
numpy
.
all
(
d
<
eps
)
@blas._constant
# TODO: move this decorator to a utility script
# This function is only executed the first time it is called, subsequent calls
# return immediately from a cache of the first return value
@blas._constant
# TODO: move this decorator to a utility file
def
_compile_dir
():
def
_compile_dir
():
"""Return the directory in which scipy.weave should store code objects.
"""Return the directory in which scipy.weave should store code objects.
...
@@ -82,73 +84,256 @@ def _compile_dir():
...
@@ -82,73 +84,256 @@ def _compile_dir():
sys
.
path
.
append
(
cachedir
)
sys
.
path
.
append
(
cachedir
)
return
cachedir
return
cachedir
class
ResultBase
:
"""Base class for storing Op inputs and outputs
Attributes:
_role - None or (owner, index) or BrokenLink
_data - anything
constant - Boolean
Properties:
role - (rw)
owner - (ro)
index - (ro)
data - (rw)
replaced - (rw) : True iff _role is BrokenLink
Abstract Methods:
data_filter
"""
class
BrokenLink
:
"""The owner of a Result that was replaced by another Result"""
__slots__
=
[
'old_role'
]
def
__init__
(
self
,
role
):
self
.
old_role
=
role
def
__nonzero__
(
self
):
return
False
class
BrokenLinkError
(
Exception
):
"""Exception thrown when an owner is a BrokenLink"""
class
AbstractFunction
(
Exception
):
"""Exception thrown when an abstract function is called"""
__slots__
=
[
'_role'
,
'_data'
,
'constant'
]
def
__init__
(
self
,
role
=
None
,
data
=
None
,
constant
=
False
):
self
.
_role
=
role
self
.
constant
=
constant
if
data
is
None
:
#None is not filtered
self
.
_data
=
None
else
:
try
:
self
.
_data
=
self
.
data_filter
(
data
)
except
ResultBase
.
AbstractFunction
:
self
.
_data
=
data
#role is pair: (owner, outputs_position)
def
__get_role
(
self
):
return
self
.
_role
def
__set_role
(
self
,
role
):
owner
,
index
=
role
if
self
.
_role
is
not
None
:
# this is either an error or a no-op
_owner
,
_index
=
self
.
_role
if
_owner
is
not
owner
:
raise
ValueError
(
"Result
%
s already has an owner."
%
self
)
if
_index
!=
index
:
raise
ValueError
(
"Result
%
s was already mapped to a different index."
%
self
)
return
# because _owner is owner and _index == index
self
.
_role
=
role
role
=
property
(
__get_role
,
__set_role
)
#owner is role[0]
def
__get_owner
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
return
self
.
_role
[
0
]
owner
=
property
(
__get_owner
,
doc
=
"Op of which this Result is an output, or None if role is None"
)
#index is role[1]
def
__get_index
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
return
self
.
_role
[
1
]
index
=
property
(
__get_index
,
doc
=
"position of self in owner's outputs, or None if role is None"
)
# assigning to self.data will invoke self.data_filter(value) if that
# function is defined
def
__get_data
(
self
):
return
self
.
_data
def
__set_data
(
self
,
data
):
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
constant
:
raise
Exception
(
'cannot set constant ResultBase'
)
try
:
self
.
_data
=
self
.
data_filter
(
self
,
data
)
except
ResultBase
.
AbstractFunction
:
self
.
_data
=
data
data
=
property
(
__get_data
,
__set_data
,
doc
=
"The storage associated with this result"
)
def
data_filter
(
self
,
data
):
"""(abstract) Return an appropriate _data based on data."""
raise
ResultBase
.
AbstractFunction
()
# replaced
def
__get_replaced
(
self
):
return
isinstance
(
self
.
_role
,
ResultBase
.
BrokenLink
)
def
__set_replaced
(
self
,
replace
):
if
replace
==
self
.
replaced
:
return
if
replace
:
self
.
_role
=
ResultBase
.
BrokenLink
(
self
.
_role
)
else
:
self
.
_role
=
self
.
_role
.
old_role
replaced
=
property
(
__get_replaced
,
__set_replaced
,
doc
=
"has this Result been replaced?"
)
class
_test_ResultBase
(
unittest
.
TestCase
):
def
test_0
(
self
):
r
=
ResultBase
()
class
Numpy2
(
ResultBase
):
"""Result storing a numpy ndarray"""
__slots__
=
[
'_dtype'
,
'_shape'
,
]
class
ShapeUnknown
:
pass
# TODO: use this as the shape of uncomputed ndarrays of unknown shape
class
StateError
(
Exception
):
pass
def
__init__
(
self
,
role
=
None
,
data
=
None
,
constant
=
False
):
if
isinstance
(
data
,
(
tuple
,
list
)):
# unallocated setup
shape
,
dtype
=
data
ResultBase
.
__init__
(
self
,
role
,
data
=
None
,
constant
=
constant
)
self
.
_shape
=
shape
self
.
_dtype
=
dtype
else
:
# allocated setup
ResultBase
.
__init__
(
self
,
role
,
data
,
constant
)
################################
# ResultBase
#
def
data_filter
(
self
,
data
):
#return numpy.asarray(data) #TODO: consider whether this is correct
if
isinstance
(
data
,
numpy
.
ndarray
):
return
data
raise
TypeError
(
'failed to filter data to ndarray'
,
data
)
################################
# Numpy2 specific functionality
#
__array__
=
property
(
lambda
self
:
self
.
_data
.
__array__
)
__array_struct__
=
property
(
lambda
self
:
self
.
_data
.
__array_struct__
)
def
data_set_inplace
(
self
,
data
):
raise
NotImplementedError
()
def
data_alloc
(
self
):
self
.
data
=
numpy
.
ndarray
(
self
.
shape
,
self
.
dtype
)
# self._dtype is used when self._data hasn't been set yet
def
__dtype_get
(
self
):
if
self
.
_data
is
None
:
return
self
.
_dtype
else
:
return
self
.
_data
.
dtype
def
__dtype_set
(
self
,
dtype
):
if
self
.
_data
is
None
:
self
.
_dtype
=
dtype
else
:
raise
StateError
(
'cannot set dtype after data has been set'
)
dtype
=
property
(
__dtype_get
,
__dtype_set
)
# self._shape is used when self._data hasn't been set yet
def
__shape_get
(
self
):
if
self
.
_data
is
None
:
return
self
.
_shape
else
:
return
self
.
_data
.
shape
def
__shape_set
(
self
,
shape
):
if
self
.
_data
is
None
:
self
.
_shape
=
shape
else
:
raise
StateError
(
'cannot set shape after data has been set'
)
shape
=
property
(
__shape_get
,
__shape_set
)
class
_test_Numpy2
(
unittest
.
TestCase
):
def
test_0
(
self
):
r
=
Numpy2
()
def
test_1
(
self
):
o
=
numpy
.
ones
((
3
,
3
))
r
=
Numpy2
(
data
=
o
)
self
.
failUnless
(
r
.
data
is
o
)
self
.
failUnless
(
r
.
shape
==
(
3
,
3
))
self
.
failUnless
(
str
(
r
.
dtype
)
==
'float64'
)
def
test_2
(
self
):
r
=
Numpy2
(
data
=
[(
3
,
3
),
'int32'
])
self
.
failUnless
(
r
.
data
is
None
)
self
.
failUnless
(
r
.
shape
==
(
3
,
3
))
self
.
failUnless
(
str
(
r
.
dtype
)
==
'int32'
)
r
.
data_alloc
()
self
.
failUnless
(
isinstance
(
r
.
data
,
numpy
.
ndarray
))
self
.
failUnless
(
r
.
shape
==
(
3
,
3
))
self
.
failUnless
(
str
(
r
.
dtype
)
==
'int32'
)
def
input
(
x
):
def
input
(
x
):
#
NB:
#
static member initialization
# - automatically casting int to float seems wrong.
if
not
hasattr
(
input
,
'float_dtype'
):
# - we want to be able to write y = x + 1 and maybe have the 1 casted to 1.0
input
.
float_dtype
=
'float64'
# at some point to maximize speed right?
input
.
int_dtype
=
'int64'
# - But more important is the ability to store index values without them
input
.
NN
=
NumpyR
# being cast to floating-point (can that cause incorrectness?)
if
isinstance
(
x
,
numpy
.
ndarray
):
if
isinstance
(
x
,
numpy
.
ndarray
):
return
NumpyR
(
x
)
return
input
.
NN
(
x
)
elif
isinstance
(
x
,
int
):
elif
isinstance
(
x
,
int
):
z
=
numpy
.
zeros
((),
dtype
=
input
.
int_dtype
)
z
=
numpy
.
zeros
((),
dtype
=
input
.
int_dtype
)
z
+=
x
z
+=
x
return
NumpyR
(
z
)
return
input
.
NN
(
z
)
elif
isinstance
(
x
,
float
):
elif
isinstance
(
x
,
float
):
z
=
numpy
.
zeros
((),
dtype
=
input
.
float_dtype
)
z
=
numpy
.
zeros
((),
dtype
=
input
.
float_dtype
)
z
+=
x
z
+=
x
return
NumpyR
(
z
)
return
input
.
NN
(
z
)
elif
isinstance
(
x
,
gof
.
Result
):
elif
isinstance
(
x
,
gof
.
Result
):
raise
TypeError
(
"
%
s is already a result."
%
x
)
raise
TypeError
(
"
%
s is already a result."
%
x
)
else
:
else
:
return
ResultValue
(
x
)
return
ResultValue
(
x
)
class
_testCase_input
(
unittest
.
TestCase
):
input
.
float_dtype
=
'float64'
def
setUp
(
self
):
input
.
int_dtype
=
'int64'
literal
.
hdb
=
{}
literal
.
udb
=
{}
def
test_input_int
(
self
):
w
=
input
(
3
)
self
.
failUnless
(
isinstance
(
w
,
input
.
NN
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
int_dtype
)
self
.
failUnless
(
w
.
data
==
3
)
def
test_input_float
(
self
):
w
=
input
(
3.0
)
self
.
failUnless
(
isinstance
(
w
,
input
.
NN
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
float_dtype
)
self
.
failUnless
(
w
.
data
==
3.0
)
def
wrap
(
x
):
def
wrap
(
x
):
if
isinstance
(
x
,
NumpyR
):
if
isinstance
(
x
,
NumpyR
):
return
x
return
x
elif
isinstance
(
x
,
Numpy2
):
return
x
elif
isinstance
(
x
,
ResultValue
):
elif
isinstance
(
x
,
ResultValue
):
return
x
return
x
elif
isinstance
(
x
,
omega_op
):
elif
isinstance
(
x
,
omega_op
):
return
x
.
out
return
x
.
out
else
:
else
:
return
literal
(
x
)
return
literal
(
x
)
class
_testCase_wrap
(
unittest
.
TestCase
):
class
_testCase_wrap
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
literal
.
hdb
=
{}
literal
.
hdb
=
{}
literal
.
udb
=
{}
literal
.
udb
=
{}
def
test_input_int
(
self
):
w
=
input
(
3
)
self
.
failUnless
(
isinstance
(
w
,
NumpyR
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
int_dtype
)
self
.
failUnless
(
w
.
data
==
3
)
def
test_input_float
(
self
):
w
=
input
(
3.0
)
self
.
failUnless
(
isinstance
(
w
,
NumpyR
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
float_dtype
)
self
.
failUnless
(
w
.
data
==
3.0
)
def
test_literal_int
(
self
):
w
=
literal
(
3
)
self
.
failUnless
(
isinstance
(
w
,
NumpyR
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
int_dtype
)
self
.
failUnless
(
w
.
data
==
3
)
def
test_literal_float
(
self
):
w
=
literal
(
3.0
)
self
.
failUnless
(
isinstance
(
w
,
NumpyR
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
float_dtype
)
self
.
failUnless
(
w
.
data
==
3.0
)
def
test_wrap_int
(
self
):
def
test_wrap_int
(
self
):
w
=
wrap
(
3
)
w
=
wrap
(
3
)
self
.
failUnless
(
isinstance
(
w
,
NumpyR
))
self
.
failUnless
(
isinstance
(
w
,
input
.
NN
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
int_dtype
)
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
int_dtype
)
self
.
failUnless
(
w
.
data
==
3
)
self
.
failUnless
(
w
.
data
==
3
)
def
test_wrap_float
(
self
):
def
test_wrap_float
(
self
):
w
=
wrap
(
3.0
)
w
=
wrap
(
3.0
)
self
.
failUnless
(
isinstance
(
w
,
NumpyR
))
self
.
failUnless
(
isinstance
(
w
,
input
.
NN
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
float_dtype
)
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
float_dtype
)
self
.
failUnless
(
w
.
data
==
3.0
)
self
.
failUnless
(
w
.
data
==
3.0
)
...
@@ -168,7 +353,7 @@ def literal(x):
...
@@ -168,7 +353,7 @@ def literal(x):
if
_hashable
(
x
):
if
_hashable
(
x
):
db
=
literal
.
hdb
db
=
literal
.
hdb
key
=
(
id
(
x
),
x
)
key
=
(
type
(
x
),
x
)
else
:
else
:
db
=
literal
.
udb
db
=
literal
.
udb
key
=
(
id
(
x
),)
key
=
(
id
(
x
),)
...
@@ -180,6 +365,33 @@ def literal(x):
...
@@ -180,6 +365,33 @@ def literal(x):
rval
.
constant
=
True
rval
.
constant
=
True
db
[
key
]
=
rval
db
[
key
]
=
rval
return
rval
return
rval
class
_testCase_literal
(
unittest
.
TestCase
):
def
setUp
(
self
):
literal
.
hdb
=
{}
literal
.
udb
=
{}
def
test_int
(
self
):
w
=
literal
(
3
)
self
.
failUnless
(
isinstance
(
w
,
input
.
NN
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
int_dtype
)
self
.
failUnless
(
w
.
data
==
3
)
u
=
literal
(
1
+
2
)
self
.
failUnless
(
u
is
w
)
def
test_float
(
self
):
w
=
literal
(
3.0
)
self
.
failUnless
(
isinstance
(
w
,
input
.
NN
))
self
.
failUnless
(
str
(
w
.
data
.
dtype
)
==
input
.
float_dtype
)
self
.
failUnless
(
w
.
data
==
3.0
)
u
=
literal
(
1.0
+
2.0
)
self
.
failUnless
(
u
is
w
)
def
test_mixed
(
self
):
f
=
literal
(
2.0
)
i
=
literal
(
2
)
self
.
failUnless
(
i
is
not
f
)
...
@@ -757,6 +969,9 @@ class NumpyR(gof.ResultValue):
...
@@ -757,6 +969,9 @@ class NumpyR(gof.ResultValue):
def
__getslice__
(
self
,
*
args
):
return
get_slice
(
self
,
slice
(
*
args
))
def
__getslice__
(
self
,
*
args
):
return
get_slice
(
self
,
slice
(
*
args
))
def
wrap_producer
(
f
):
def
wrap_producer
(
f
):
class
producer
(
omega_op
):
class
producer
(
omega_op
):
impl
=
f
impl
=
f
...
@@ -1075,7 +1290,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1075,7 +1290,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'
matrices are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'
objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
...
@@ -1085,7 +1300,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1085,7 +1300,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'
matrices are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'
objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
def
test_dot_fail_1_3
(
self
):
def
test_dot_fail_1_3
(
self
):
...
@@ -1094,7 +1309,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1094,7 +1309,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
def
test_dot_fail_2_1
(
self
):
def
test_dot_fail_2_1
(
self
):
...
@@ -1103,7 +1318,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1103,7 +1318,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'
matrices are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'
objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
def
test_dot_fail_2_2
(
self
):
def
test_dot_fail_2_2
(
self
):
...
@@ -1112,7 +1327,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1112,7 +1327,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'
matrices are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'
objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
def
test_dot_fail_2_3
(
self
):
def
test_dot_fail_2_3
(
self
):
...
@@ -1121,7 +1336,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1121,7 +1336,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
def
test_dot_fail_3_1
(
self
):
def
test_dot_fail_3_1
(
self
):
...
@@ -1130,7 +1345,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1130,7 +1345,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
def
test_dot_fail_3_2
(
self
):
def
test_dot_fail_3_2
(
self
):
...
@@ -1139,7 +1354,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1139,7 +1354,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
def
test_dot_fail_3_3
(
self
):
def
test_dot_fail_3_3
(
self
):
...
@@ -1148,7 +1363,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1148,7 +1363,7 @@ class _testCase_dot(unittest.TestCase):
try
:
try
:
z
=
dot
(
x
,
y
)
z
=
dot
(
x
,
y
)
except
ValueError
,
e
:
except
ValueError
,
e
:
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
)
self
.
failUnless
(
str
(
e
)
==
'objects are not aligned'
,
e
)
return
return
self
.
fail
()
self
.
fail
()
...
@@ -1425,15 +1640,19 @@ class _testCase_power(unittest.TestCase):
...
@@ -1425,15 +1640,19 @@ class _testCase_power(unittest.TestCase):
numpy
.
random
.
seed
(
44
)
numpy
.
random
.
seed
(
44
)
def
tearDown
(
self
):
def
tearDown
(
self
):
pop_mode
()
pop_mode
()
def
test_0
(
self
):
def
test_0
(
self
):
r
=
numpy
.
random
.
rand
(
50
)
r
=
numpy
.
random
.
rand
(
50
)
er
=
exp
(
r
)
ler
=
log
(
er
)
a
,
b
=
numpy
.
max
(
ler
-
r
),
numpy
.
min
(
ler
-
r
)
exp_r
=
exp
(
r
)
self
.
failUnless
(
a
<
1.0e-13
and
b
>
-
1.0e-13
,
'exp and log are not inverses'
)
self
.
failUnless
(
_approx_eq
(
exp_r
,
numpy
.
exp
(
r
))
)
log_exp_r
=
log
(
exp_r
)
self
.
failUnless
(
_approx_eq
(
log_exp_r
,
r
))
def
test_1
(
self
):
r
=
numpy
.
random
.
rand
(
50
)
r2
=
pow
(
r
,
2
)
self
.
failUnless
(
_approx_eq
(
r2
,
r
*
r
))
## Others ##
## Others ##
...
...
gof/lib.py
浏览文件 @
460f6b78
...
@@ -58,6 +58,10 @@ def compute(*nodes):
...
@@ -58,6 +58,10 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way)."""
"""Recursively evaluate each node (in a quick & dirty way)."""
compute_from
(
nodes
,
set
())
compute_from
(
nodes
,
set
())
def
is_result
(
obj
):
"""Return True iff obj provides the interface of a Result"""
attr_list
=
'data'
,
'owner'
return
all
([
hasattr
(
obj
,
attr
)
for
attr
in
attr_list
])
class
ForbidConstantOverwrite
(
features
.
Listener
,
features
.
Constraint
):
class
ForbidConstantOverwrite
(
features
.
Listener
,
features
.
Constraint
):
...
@@ -460,8 +464,7 @@ class PythonOp(Op):
...
@@ -460,8 +464,7 @@ class PythonOp(Op):
Op
.
__init__
(
self
,
inputs
,
self
.
gen_outputs
())
Op
.
__init__
(
self
,
inputs
,
self
.
gen_outputs
())
def
__validate__
(
self
):
def
__validate__
(
self
):
for
input
in
self
.
inputs
:
return
all
([
is_result
(
i
)
for
i
in
self
.
inputs
])
assert
isinstance
(
input
,
ResultValue
)
def
gen_outputs
(
self
):
def
gen_outputs
(
self
):
return
[
ResultValue
()
for
i
in
xrange
(
self
.
nout
)]
return
[
ResultValue
()
for
i
in
xrange
(
self
.
nout
)]
...
...
grad.py
浏览文件 @
460f6b78
import
gof
import
gof
from
gof.lib
import
compute_from
from
gof.lib
import
compute_from
,
is_result
import
core
import
core
class
Grad
(
object
):
class
Grad
(
object
):
...
@@ -173,7 +173,7 @@ class update_gradient_via_grad:
...
@@ -173,7 +173,7 @@ class update_gradient_via_grad:
"""
"""
inputgs
=
self
.
grad
(
*
(
self
.
inputs
+
[
grad_d
[
output
]
for
output
in
self
.
outputs
]))
inputgs
=
self
.
grad
(
*
(
self
.
inputs
+
[
grad_d
[
output
]
for
output
in
self
.
outputs
]))
if
len
(
self
.
inputs
)
==
1
and
is
instance
(
inputgs
,
gof
.
ResultValue
):
if
len
(
self
.
inputs
)
==
1
and
is
_result
(
inputgs
):
inputgs
=
[
inputgs
]
inputgs
=
[
inputgs
]
else
:
else
:
assert
len
(
inputgs
)
==
len
(
self
.
inputs
)
assert
len
(
inputgs
)
==
len
(
self
.
inputs
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论