Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0f344ee4
提交
0f344ee4
authored
3月 19, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
created BaseTensorOp in base_tensor
上级
ca315a08
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
123 行增加
和
63 行删除
+123
-63
_test_tensor.py
_test_tensor.py
+5
-5
base_tensor.py
base_tensor.py
+66
-1
elemwise.py
elemwise.py
+1
-1
_test_opt.py
gof/_test_opt.py
+15
-1
tensor.py
tensor.py
+36
-55
没有找到文件。
_test_tensor.py
浏览文件 @
0f344ee4
...
...
@@ -101,15 +101,15 @@ class T_abs(unittest.TestCase):
class
T_fill
(
unittest
.
TestCase
):
def
test0
(
self
):
t
=
fill
(
numpy
.
asarray
([
1
,
2
,
3
]),
9
.0
)
t
=
fill
(
numpy
.
asarray
([
1
,
2
,
3
]),
9
)
self
.
failUnless
(
t
.
owner
.
__class__
==
Fill
)
o
=
t
.
owner
self
.
failUnless
(
o
.
inputs
[
0
]
.
broadcastable
==
(
0
,))
self
.
failUnless
(
o
.
inputs
[
0
]
.
dtype
[
0
:
3
]
==
'int'
)
#
self.failUnless(o.inputs[0].dtype[0:3] == 'int')
self
.
failUnless
(
o
.
inputs
[
1
]
.
broadcastable
==
())
self
.
failUnless
(
o
.
inputs
[
1
]
.
dtype
[
0
:
3
]
==
'flo'
)
#
self.failUnless(o.inputs[1].dtype[0:3] == 'flo')
self
.
failUnless
(
o
.
outputs
[
0
]
.
broadcastable
==
(
0
,))
self
.
failUnless
(
o
.
outputs
[
0
]
.
dtype
[
0
:
3
]
==
'flo'
)
#
self.failUnless(o.outputs[0].dtype[0:3] == 'flo')
class
T_sum
(
unittest
.
TestCase
):
def
test_impl
(
self
):
...
...
@@ -152,7 +152,7 @@ class T_mul(unittest.TestCase):
def
test_operator
(
self
):
a
=
tinit
([
1
,
1
])
aa
=
tinit
([
1
,
1
])
b
=
tinit
(
4
.0
)
b
=
tinit
(
4
)
self
.
failUnless
(
isinstance
((
a
*
b
)
.
owner
,
Scale
))
self
.
failUnless
(
isinstance
((
b
*
a
)
.
owner
,
Scale
))
self
.
failUnless
(
isinstance
((
a
*
aa
)
.
owner
,
MulElemwise
))
...
...
base_tensor.py
浏览文件 @
0f344ee4
"""A simple class to store ndarray data """
from
gof
import
ResultBase
from
gof
import
ResultBase
,
Op
,
utils
import
numpy
from
copy
import
copy
...
...
@@ -194,3 +194,68 @@ class BaseTensor(ResultBase):
class
BaseTensorOp
(
Op
):
"""
A basic Op subclass that can be used to make Ops that operate on Tensors.
It is not mandatory to inherit from this class, but it is practical.
BasicTensorOp is parametrized as follows:
* nin: number of inputs
* nout: number of outputs
* out_tensor_class: BaseTensor subclass used to instantiate the outputs
* input_wrapper: returns a Tensor from its argument
* propagate_dtype: returns a list of dtypes corresponding to the
output dtypes from a list of input dtypes (if an input is
not a Tensor, the passed value will be None)
* propagate_broadcastable: returns a list of tuples corresponding to
the output broadcastable flags from the input broadcastable
flags (if an input is not a Tensor, the passed value will be
None).
"""
nin
=
-
1
# nin == -1 means: arbitrary number of inputs
nout
=
1
out_tensor_class
=
BaseTensor
@classmethod
def
input_wrapper
(
cls
,
obj
):
"""
Returns a Result from an arbitrary-typed input, if possible.
"""
if
isinstance
(
obj
,
BaseResult
):
return
obj
else
:
raise
TypeError
(
"Expected a Result instance."
)
def
__init__
(
self
,
*
inputs
):
inputs
=
map
(
self
.
input_wrapper
,
inputs
)
if
self
.
nin
>=
0
:
if
len
(
inputs
)
!=
self
.
nin
:
raise
TypeError
(
"Wrong number of inputs for
%
s (got
%
i, expected
%
i)"
)
\
%
(
self
,
len
(
inputs
),
self
.
nin
)
i_broadcastables
=
[
getattr
(
input
,
'broadcastable'
,
None
)
for
input
in
inputs
]
i_dtypes
=
[
getattr
(
input
,
'dtype'
,
None
)
for
input
in
inputs
]
o_broadcastables
=
utils
.
from_return_values
(
self
.
propagate_broadcastable
(
*
i_broadcastables
))
o_dtypes
=
utils
.
from_return_values
(
self
.
propagate_dtype
(
*
i_dtypes
))
self
.
inputs
=
inputs
self
.
outputs
=
[
self
.
out_tensor_class
(
dtype
,
broadcastable
)
for
broadcastable
,
dtype
in
zip
(
o_broadcastables
,
o_dtypes
)]
def
propagate_broadcastable
(
self
,
*
inputs
):
raise
AbstractFunctionError
()
def
propagate_dtype
(
self
,
*
i_dtypes
):
rval
=
set
([
dtype
for
dtype
in
i_dtypes
if
dtype
is
not
None
])
if
len
(
rval
)
==
0
:
raise
ValueError
(
"Cannot infer the dtypes of the outputs with no Tensor inputs."
)
elif
len
(
rval
)
>
1
:
raise
ValueError
(
"The dtypes of all inputs should be identical."
)
return
[
rval
.
pop
()]
*
self
.
nout
elemwise.py
浏览文件 @
0f344ee4
...
...
@@ -78,7 +78,7 @@ class Elemwise(Op):
return
code_cleanup
@classmethod
def
inplace_version
(
cls
):
def
inplace_version
(
cls
,
dmap
=
{
0
:
0
}
):
class
Ret
(
cls
,
Destroyer
):
def
destroy_map
(
self
):
return
{
self
.
outputs
[
0
]:
[
self
.
inputs
[
0
]]}
...
...
gof/_test_opt.py
浏览文件 @
0f344ee4
...
...
@@ -165,7 +165,21 @@ class _test_MergeOptimizer(unittest.TestCase):
assert
str
(
g
)
==
"[Op1(*1 -> Op2(x, y), *1, *1)]"
\
or
str
(
g
)
==
"[Op1(*1 -> Op2(x, z), *1, *1)]"
def
test_2
(
self
):
class
_test_ConstantFinder
(
unittest
.
TestCase
):
def
test_0
(
self
):
x
,
y
,
z
=
inputs
()
y
.
data
=
2
z
.
data
=
2
e
=
op1
(
x
,
y
,
z
)
g
=
env
([
x
],
[
e
])
ConstantFinder
()
.
optimize
(
g
)
MergeOptimizer
()
.
optimize
(
g
)
assert
str
(
g
)
==
"[Op1(x, y, y)]"
\
or
str
(
g
)
==
"[Op1(x, z, z)]"
def
test_1
(
self
):
x
,
y
,
z
=
inputs
()
y
.
data
=
2
z
.
data
=
2
...
...
tensor.py
浏览文件 @
0f344ee4
...
...
@@ -5,7 +5,7 @@ from copy import copy
import
inspect
from
gof
import
ResultBase
,
Op
,
utils
,
Destroyer
,
Viewer
,
AbstractFunctionError
from
base_tensor
import
BaseTensor
from
base_tensor
import
BaseTensor
,
BaseTensorOp
from
elemwise
import
Elemwise
...
...
@@ -114,64 +114,45 @@ def _assert_tensor_scalar(x, a):
raise
ValueError
(
"The second argument must be a scalar."
)
class
_Op
(
Op
):
class
_Op
(
BaseTensor
Op
):
"""A convenient base for the ops in this file"""
nin
=
-
1
nout
=
1
_destroy_map
=
{}
out_tensor_class
=
Tensor
def
__init__
(
self
,
*
inputs
):
def
as_tensor
(
obj
):
if
isinstance
(
obj
,
Tensor
):
return
obj
else
:
return
tinit
(
obj
)
inputs
=
map
(
as_tensor
,
inputs
)
if
self
.
nin
>=
0
:
if
len
(
inputs
)
!=
self
.
nin
:
raise
TypeError
(
"Wrong number of inputs for
%
s (got
%
i, expected
%
i)"
)
\
%
(
self
,
len
(
inputs
),
self
.
nin
)
i_broadcastables
=
[
getattr
(
input
,
'broadcastable'
,
None
)
for
input
in
inputs
]
i_dtypes
=
[
getattr
(
input
,
'dtype'
,
None
)
for
input
in
inputs
]
o_broadcastables
=
utils
.
from_return_values
(
self
.
propagate_broadcastable
(
*
i_broadcastables
))
o_dtypes
=
utils
.
from_return_values
(
self
.
propagate_dtype
(
*
i_dtypes
))
self
.
inputs
=
inputs
self
.
outputs
=
[
Tensor
(
dtype
,
broadcastable
)
for
broadcastable
,
dtype
in
zip
(
o_broadcastables
,
o_dtypes
)]
def
propagate_broadcastable
(
self
,
*
inputs
):
raise
AbstractFunctionError
()
@classmethod
def
input_wrapper
(
cls
,
obj
):
if
isinstance
(
obj
,
Tensor
):
return
obj
else
:
return
tinit
(
obj
)
# nin = -1
# nout = 1
def
propagate_dtype
(
self
,
*
i_dtypes
):
def
upcast
(
dtype
,
*
dtypes
):
z
=
numpy
.
zeros
((),
dtype
=
dtype
)
for
dtype
in
dtypes
:
z
=
z
+
numpy
.
zeros
((),
dtype
=
dtype
)
return
str
(
z
.
dtype
)
for
dtype
in
i_dtypes
:
if
dtype
is
None
:
raise
TypeError
(
"Expected a Tensor."
)
upcasted
=
upcast
(
*
i_dtypes
)
return
[
upcasted
]
*
self
.
nout
# try:
# dmap = self.destroy_map()
# except AttributeError:
# dmap = {}
# rval = []
# for i in xrange(self.nout):
# if i in dmap:
# destroyed = dmap[output]
# if len(destroyed) != 1:
# raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# rval.append(destroyed[0])
# else:
# rval.append(upcasted)
# return rval
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
def
impl
(
self
,
*
inputs
):
raise
AbstractFunctionError
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论