Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e926c476
提交
e926c476
authored
10月 30, 2011
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
adding comments and TODOs to ShapeFeature
上级
3540ba9e
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
35 行增加
和
8 行删除
+35
-8
opt.py
theano/tensor/opt.py
+35
-8
没有找到文件。
theano/tensor/opt.py
浏览文件 @
e926c476
...
@@ -629,18 +629,23 @@ class ShapeFeature(object):
...
@@ -629,18 +629,23 @@ class ShapeFeature(object):
"""
"""
def shape_ir(self, i, r):
    """Return a symbolic lscalar for r.shape[i] of tensor variable r.

    :param i: int, the dimension index
    :param r: a tensor variable
    """
    r_type = r.type
    if hasattr(r_type, "broadcastable") and r_type.broadcastable[i]:
        # A broadcastable dimension always has length 1; hand back the
        # shared constant so duplicated "1"s are a single object.
        return self.lscalar_one
    # Non-broadcastable: build a Shape_i node computing the length at runtime.
    return Shape_i(i).make_node(r).outputs[0]
def shape_tuple(self, r):
    """Return a tuple of symbolic shape variables, one per dimension of r.

    Each element comes from self.shape_ir, so broadcastable dimensions
    map to the shared lscalar_one constant.

    :param r: a tensor variable with an ``ndim`` attribute
    """
    # `range` (not the Python-2-only `xrange`) keeps this working on both
    # Python 2 and 3; tuple() consumes the whole sequence either way, so
    # behavior is identical.
    return tuple(self.shape_ir(i, r) for i in range(r.ndim))
def
default_infer_shape
(
self
,
node
,
i_shapes
):
def
default_infer_shape
(
self
,
node
,
i_shapes
):
"""Return a list of shape tuple or None for the outputs of node.
This function is used for Ops that don't implement infer_shape.
Ops that do implement infer_shape should use the i_shapes parameter,
but this default implementation ignores it.
"""
rval
=
[]
rval
=
[]
for
r
in
node
.
outputs
:
for
r
in
node
.
outputs
:
try
:
try
:
...
@@ -650,16 +655,21 @@ class ShapeFeature(object):
...
@@ -650,16 +655,21 @@ class ShapeFeature(object):
return
rval
return
rval
def
unpack
(
self
,
s_i
):
def
unpack
(
self
,
s_i
):
"""Return a symbolic integer scalar for the shape element s_i.
The s_i argument was produced by the infer_shape() of an Op subclass.
"""
# unpack the s_i that the Op returned
# unpack the s_i that the Op returned
assert
s_i
is
not
None
assert
s_i
is
not
None
if
s_i
==
1
:
if
s_i
==
1
:
# don't make the optimizer merge a zillion ones together
# don't make the optimizer merge a zillion ones together
# by always returning the same object to represent 1
return
self
.
lscalar_one
return
self
.
lscalar_one
if
type
(
s_i
)
in
(
int
,
long
)
or
isinstance
(
s_i
,
numpy
.
integer
):
if
type
(
s_i
)
in
(
int
,
long
)
or
isinstance
(
s_i
,
numpy
.
integer
):
# this shape is a constant
# this shape is a constant
assert
s_i
>=
0
assert
s_i
>=
0
return
T
.
constant
(
s_i
,
dtype
=
'int64'
)
return
T
.
constant
(
s_i
,
dtype
=
'int64'
)
if
type
(
s_i
)
in
(
tuple
,
list
):
if
type
(
s_i
)
in
(
tuple
,
list
):
# this dimension is the same as many of the inputs
# this dimension is the same as many of the inputs
# which tells us that if one of the inputs is known,
# which tells us that if one of the inputs is known,
# the others all become known.
# the others all become known.
...
@@ -676,12 +686,19 @@ class ShapeFeature(object):
...
@@ -676,12 +686,19 @@ class ShapeFeature(object):
s_i
,
type
(
s_i
),
getattr
(
s_i
,
'type'
,
None
))
s_i
,
type
(
s_i
),
getattr
(
s_i
,
'type'
,
None
))
def set_shape(self, r, s):
    """Record the shape `s` for the not-yet-tracked variable `r`.

    :type r: a variable
    :type s: None or a tuple of symbolic integers

    Each element of `s` is normalized through self.unpack; `None` means
    the shape is unknown.
    """
    assert r not in self.shape_of, 'r already in shape_of'
    if s is None:
        entry = None
    else:
        entry = tuple(self.unpack(element) for element in s)
    self.shape_of[r] = entry
    # XXX: add a reverse index from the tuple elements -> r
def
update_shape
(
self
,
r
,
other_r
):
def
update_shape
(
self
,
r
,
other_r
):
'''Replace shape of r by shape of other_r.
'''Replace shape of r by shape of other_r.
...
@@ -697,10 +714,14 @@ class ShapeFeature(object):
...
@@ -697,10 +714,14 @@ class ShapeFeature(object):
else
:
else
:
# If no info is known on r's shape, use other_shape
# If no info is known on r's shape, use other_shape
self
.
shape_of
[
r
]
=
other_shape
self
.
shape_of
[
r
]
=
other_shape
#XXX: add reverse index from elements of other_shape -> r
return
return
# If other_shape has no information, use r_shape
# If other_shape has no information, call is pointless.
# XXX: move this above the previous if/else block
if
other_shape
is
None
:
if
other_shape
is
None
:
# XXX: no need to assign back, delete following line
self
.
shape_of
[
r
]
=
r_shape
self
.
shape_of
[
r
]
=
r_shape
return
return
...
@@ -719,6 +740,7 @@ class ShapeFeature(object):
...
@@ -719,6 +740,7 @@ class ShapeFeature(object):
else
:
else
:
merged_shape
.
append
(
other_shape
[
i
])
merged_shape
.
append
(
other_shape
[
i
])
self
.
shape_of
[
r
]
=
tuple
(
merged_shape
)
self
.
shape_of
[
r
]
=
tuple
(
merged_shape
)
# XXX: update reverse index
def
set_shape_i
(
self
,
r
,
i
,
s_i
):
def
set_shape_i
(
self
,
r
,
i
,
s_i
):
'''Replace element i of shape_of[r] by s_i'''
'''Replace element i of shape_of[r] by s_i'''
...
@@ -733,13 +755,15 @@ class ShapeFeature(object):
...
@@ -733,13 +755,15 @@ class ShapeFeature(object):
else
:
else
:
new_shape
.
append
(
s_j
)
new_shape
.
append
(
s_j
)
self
.
shape_of
[
r
]
=
tuple
(
new_shape
)
self
.
shape_of
[
r
]
=
tuple
(
new_shape
)
# XXX: update reverse index
def init_r(self, r):
    '''Register r's shape in the shape_of dictionary.'''
    if r in self.shape_of:
        # Already tracked; keep the existing entry untouched.
        return
    try:
        self.set_shape(r, self.shape_tuple(r))
    except AttributeError:
        # XXX: where would this AttributeError come from? Fall back to
        # "shape unknown" for variables we cannot build a shape tuple for.
        self.set_shape(r, None)
def
make_vector_shape
(
self
,
r
):
def
make_vector_shape
(
self
,
r
):
...
@@ -759,6 +783,7 @@ class ShapeFeature(object):
...
@@ -759,6 +783,7 @@ class ShapeFeature(object):
self
.
shape_of
=
{}
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self
.
shape_of
=
{}
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self
.
scheduled
=
{}
# Variable ->
self
.
scheduled
=
{}
# Variable ->
# XXX: create reverse index
for
node
in
env
.
toposort
():
for
node
in
env
.
toposort
():
self
.
on_import
(
env
,
node
)
self
.
on_import
(
env
,
node
)
...
@@ -798,9 +823,11 @@ class ShapeFeature(object):
...
@@ -798,9 +823,11 @@ class ShapeFeature(object):
# this is packed information
# this is packed information
# an element of o_shapes is either None or a tuple
# an element of o_shapes is either None or a tuple
# elements of the tuple can be either strings, or ints
# elements of the tuple can be either strings, or ints
if
len
(
o_shapes
)
!=
len
(
node
.
outputs
):
if
len
(
o_shapes
)
!=
len
(
node
.
outputs
):
raise
Exception
(
'len(o_shapes) = '
+
str
(
len
(
o_shapes
))
+
' != len(node.outputs) = '
+
str
(
len
(
node
.
outputs
)))
raise
Exception
(
'len(o_shapes) = '
+
str
(
len
(
o_shapes
))
+
' != len(node.outputs) = '
+
str
(
len
(
node
.
outputs
)))
for
r
,
s
in
zip
(
node
.
outputs
,
o_shapes
):
for
r
,
s
in
zip
(
node
.
outputs
,
o_shapes
):
self
.
set_shape
(
r
,
s
)
self
.
set_shape
(
r
,
s
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论