Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2c125069
提交
2c125069
authored
6月 26, 2015
作者:
Iban Harlouchet
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
flake8 tensor/subtensor.py
上级
3c5b4282
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
23 行增加
和
26 行删除
+23
-26
subtensor.py
theano/tensor/subtensor.py
+23
-25
test_flake8.py
theano/tests/test_flake8.py
+0
-1
没有找到文件。
theano/tensor/subtensor.py
浏览文件 @
2c125069
from
copy
import
copy
from
copy
import
copy
import
os
import
sys
import
sys
from
textwrap
import
dedent
from
textwrap
import
dedent
import
warnings
import
warnings
import
logging
import
logging
_logger
=
logging
.
getLogger
(
"theano.tensor.subtensor"
)
import
numpy
import
numpy
from
six.moves
import
xrange
from
six.moves
import
xrange
...
@@ -32,6 +30,7 @@ if config.cxx:
...
@@ -32,6 +30,7 @@ if config.cxx:
except
ImportError
:
except
ImportError
:
pass
pass
_logger
=
logging
.
getLogger
(
"theano.tensor.subtensor"
)
# Do a lazy import of the sparse module
# Do a lazy import of the sparse module
sparse_module_ref
=
None
sparse_module_ref
=
None
...
@@ -336,9 +335,9 @@ class Subtensor(Op):
...
@@ -336,9 +335,9 @@ class Subtensor(Op):
theano
.
tensor
.
wscalar
,
theano
.
tensor
.
bscalar
]
theano
.
tensor
.
wscalar
,
theano
.
tensor
.
bscalar
]
invalid_tensor_types
=
[
theano
.
tensor
.
fscalar
,
theano
.
tensor
.
dscalar
,
invalid_tensor_types
=
[
theano
.
tensor
.
fscalar
,
theano
.
tensor
.
dscalar
,
theano
.
tensor
.
cscalar
,
theano
.
tensor
.
zscalar
]
theano
.
tensor
.
cscalar
,
theano
.
tensor
.
zscalar
]
if
(
isinstance
(
entry
,
gof
.
Variable
)
if
(
isinstance
(
entry
,
gof
.
Variable
)
and
and
(
entry
.
type
in
invalid_scal_types
(
entry
.
type
in
invalid_scal_types
or
or
entry
.
type
in
invalid_tensor_types
)):
entry
.
type
in
invalid_tensor_types
)):
raise
TypeError
(
"Expected an integer"
)
raise
TypeError
(
"Expected an integer"
)
if
isinstance
(
entry
,
gof
.
Variable
)
and
entry
.
type
in
scal_types
:
if
isinstance
(
entry
,
gof
.
Variable
)
and
entry
.
type
in
scal_types
:
...
@@ -346,13 +345,13 @@ class Subtensor(Op):
...
@@ -346,13 +345,13 @@ class Subtensor(Op):
elif
isinstance
(
entry
,
gof
.
Type
)
and
entry
in
scal_types
:
elif
isinstance
(
entry
,
gof
.
Type
)
and
entry
in
scal_types
:
return
entry
return
entry
if
(
isinstance
(
entry
,
gof
.
Variable
)
if
(
isinstance
(
entry
,
gof
.
Variable
)
and
and
entry
.
type
in
tensor_types
entry
.
type
in
tensor_types
and
and
numpy
.
all
(
entry
.
type
.
broadcastable
)):
numpy
.
all
(
entry
.
type
.
broadcastable
)):
return
scal
.
get_scalar_type
(
entry
.
type
.
dtype
)
return
scal
.
get_scalar_type
(
entry
.
type
.
dtype
)
elif
(
isinstance
(
entry
,
gof
.
Type
)
elif
(
isinstance
(
entry
,
gof
.
Type
)
and
and
entry
in
tensor_types
entry
in
tensor_types
and
and
numpy
.
all
(
entry
.
broadcastable
)):
numpy
.
all
(
entry
.
broadcastable
)):
return
scal
.
get_scalar_type
(
entry
.
dtype
)
return
scal
.
get_scalar_type
(
entry
.
dtype
)
elif
slice_ok
and
isinstance
(
entry
,
slice
):
elif
slice_ok
and
isinstance
(
entry
,
slice
):
a
=
entry
.
start
a
=
entry
.
start
...
@@ -425,8 +424,9 @@ class Subtensor(Op):
...
@@ -425,8 +424,9 @@ class Subtensor(Op):
conv
(
val
.
step
))
conv
(
val
.
step
))
else
:
else
:
try
:
try
:
return
get_scalar_constant_value
(
val
,
return
get_scalar_constant_value
(
only_process_constants
=
only_process_constants
)
val
,
only_process_constants
=
only_process_constants
)
except
theano
.
tensor
.
NotScalarConstantError
:
except
theano
.
tensor
.
NotScalarConstantError
:
if
allow_partial
:
if
allow_partial
:
return
val
return
val
...
@@ -477,8 +477,8 @@ class Subtensor(Op):
...
@@ -477,8 +477,8 @@ class Subtensor(Op):
%
(
input
.
type
,
expected_type
))
%
(
input
.
type
,
expected_type
))
# infer the broadcasting pattern
# infer the broadcasting pattern
padded
=
(
self
.
get_constant_idx
((
None
,)
+
inputs
,
allow_partial
=
True
)
padded
=
(
self
.
get_constant_idx
((
None
,)
+
inputs
,
allow_partial
=
True
)
+
+
[
slice
(
None
,
None
,
None
)]
*
(
x
.
type
.
ndim
-
len
(
idx_list
)))
[
slice
(
None
,
None
,
None
)]
*
(
x
.
type
.
ndim
-
len
(
idx_list
)))
broadcastable
=
[]
broadcastable
=
[]
for
i
,
(
p
,
bc
)
in
enumerate
(
izip
(
padded
,
x
.
type
.
broadcastable
)):
for
i
,
(
p
,
bc
)
in
enumerate
(
izip
(
padded
,
x
.
type
.
broadcastable
)):
if
isinstance
(
p
,
slice
):
if
isinstance
(
p
,
slice
):
...
@@ -528,9 +528,9 @@ class Subtensor(Op):
...
@@ -528,9 +528,9 @@ class Subtensor(Op):
if
isinstance
(
idx
,
slice
):
if
isinstance
(
idx
,
slice
):
# If it is the default (None, None, None) slice, or a variant,
# If it is the default (None, None, None) slice, or a variant,
# the shape will be xl
# the shape will be xl
if
((
idx
.
start
in
[
None
,
0
])
if
((
idx
.
start
in
[
None
,
0
])
and
and
(
idx
.
stop
in
[
None
,
sys
.
maxsize
])
(
idx
.
stop
in
[
None
,
sys
.
maxsize
])
and
and
(
idx
.
step
is
None
or
idx
.
step
==
1
)):
(
idx
.
step
is
None
or
idx
.
step
==
1
)):
outshp
.
append
(
xl
)
outshp
.
append
(
xl
)
else
:
else
:
cnf
=
get_canonical_form_slice
(
idx
,
xl
)[
0
]
cnf
=
get_canonical_form_slice
(
idx
,
xl
)[
0
]
...
@@ -556,8 +556,7 @@ class Subtensor(Op):
...
@@ -556,8 +556,7 @@ class Subtensor(Op):
first
=
x
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
)
first
=
x
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
)
else
:
else
:
first
=
IncSubtensor
(
self
.
idx_list
)(
x
.
zeros_like
(),
gz
,
*
rest
)
first
=
IncSubtensor
(
self
.
idx_list
)(
x
.
zeros_like
(),
gz
,
*
rest
)
return
([
first
]
return
([
first
]
+
[
DisconnectedType
()()]
*
len
(
rest
))
+
[
DisconnectedType
()()]
*
len
(
rest
))
def
connection_pattern
(
self
,
node
):
def
connection_pattern
(
self
,
node
):
...
@@ -1034,8 +1033,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
...
@@ -1034,8 +1033,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
dim_offset
=
x
.
ndim
-
y
.
ndim
dim_offset
=
x
.
ndim
-
y
.
ndim
for
dim
in
xrange
(
y
.
ndim
):
for
dim
in
xrange
(
y
.
ndim
):
if
(
x
.
broadcastable
[
dim
+
dim_offset
]
if
(
x
.
broadcastable
[
dim
+
dim_offset
]
and
not
y
.
broadcastable
[
dim
]):
and
not
y
.
broadcastable
[
dim
]):
# It is acceptable to try to increment a subtensor with a
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# on that dimension. However, its length must then be 1.
...
@@ -2133,9 +2131,9 @@ class AdvancedIncSubtensor(Op):
...
@@ -2133,9 +2131,9 @@ class AdvancedIncSubtensor(Op):
return
hash
((
type
(
self
),
self
.
inplace
,
self
.
set_instead_of_inc
))
return
hash
((
type
(
self
),
self
.
inplace
,
self
.
set_instead_of_inc
))
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
return
(
type
(
self
)
==
type
(
other
)
and
and
self
.
inplace
==
other
.
inplace
self
.
inplace
==
other
.
inplace
and
and
self
.
set_instead_of_inc
==
other
.
set_instead_of_inc
)
self
.
set_instead_of_inc
==
other
.
set_instead_of_inc
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s{
%
s,
%
s}"
%
(
self
.
__class__
.
__name__
,
return
"
%
s{
%
s,
%
s}"
%
(
self
.
__class__
.
__name__
,
...
...
theano/tests/test_flake8.py
浏览文件 @
2c125069
...
@@ -57,7 +57,6 @@ whitelist_flake8 = [
...
@@ -57,7 +57,6 @@ whitelist_flake8 = [
"typed_list/tests/test_type.py"
,
"typed_list/tests/test_type.py"
,
"typed_list/tests/test_opt.py"
,
"typed_list/tests/test_opt.py"
,
"typed_list/tests/test_basic.py"
,
"typed_list/tests/test_basic.py"
,
"tensor/subtensor.py"
,
"tensor/elemwise.py"
,
"tensor/elemwise.py"
,
"tensor/xlogx.py"
,
"tensor/xlogx.py"
,
"tensor/blas_headers.py"
,
"tensor/blas_headers.py"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论