Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0e35cc21
提交
0e35cc21
authored
3月 28, 2013
作者:
lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1299 from delallea/canonical_slice
Simpler graphs for canonical slices
上级
ef158c39
cb47093e
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
140 行增加
和
31 行删除
+140
-31
basic.py
theano/tensor/basic.py
+128
-31
test_basic.py
theano/tensor/tests/test_basic.py
+12
-0
没有找到文件。
theano/tensor/basic.py
浏览文件 @
0e35cc21
...
@@ -510,7 +510,7 @@ def get_scalar_constant_value(v):
...
@@ -510,7 +510,7 @@ def get_scalar_constant_value(v):
if
isinstance
(
v
,
(
numpy
.
integer
,
int
,
float
)):
if
isinstance
(
v
,
(
numpy
.
integer
,
int
,
float
)):
return
numpy
.
asarray
(
v
)
return
numpy
.
asarray
(
v
)
def
numpy_scalar
(
n
):
def
numpy_scalar
(
data
):
""" Return a scalar stored in a numpy ndarray, or raise
""" Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
NotScalarConstantError if the numpy ndarray is not a scalar
"""
"""
...
@@ -526,7 +526,7 @@ def get_scalar_constant_value(v):
...
@@ -526,7 +526,7 @@ def get_scalar_constant_value(v):
except
Exception
:
except
Exception
:
raise
NotScalarConstantError
(
raise
NotScalarConstantError
(
'v.data is non-numeric, non-scalar, or has more than one'
'v.data is non-numeric, non-scalar, or has more than one'
' unique value'
,
n
)
' unique value'
,
data
)
if
isinstance
(
v
,
numpy
.
ndarray
):
if
isinstance
(
v
,
numpy
.
ndarray
):
return
numpy_scalar
(
v
)
return
numpy_scalar
(
v
)
...
@@ -4250,27 +4250,114 @@ def get_canonical_form_slice(theslice, length):
...
@@ -4250,27 +4250,114 @@ def get_canonical_form_slice(theslice, length):
that respects the conventions imposed by python and numpy.
that respects the conventions imposed by python and numpy.
In a canonical form a slice is represented by a canonical form slice,
In a canonical form a slice is represented by a canonical form slice,
in which
the start <= stop and step >0 and a flag which says if the
in which
0 <= start <= stop <= length and step > 0, and a flag which says
resulting set of numbers needs to be reversed or not.
if the
resulting set of numbers needs to be reversed or not.
'''
'''
if
isinstance
(
theslice
,
slice
):
if
isinstance
(
theslice
,
slice
):
start
=
extract_constant
(
theslice
.
start
)
def
analyze
(
x
):
stop
=
extract_constant
(
theslice
.
stop
)
try
:
step
=
extract_constant
(
theslice
.
step
)
x_constant
=
get_scalar_constant_value
(
x
)
is_constant
=
True
except
NotScalarConstantError
:
x_constant
=
extract_constant
(
x
)
is_constant
=
False
return
x_constant
,
is_constant
start
,
is_start_constant
=
analyze
(
theslice
.
start
)
stop
,
is_stop_constant
=
analyze
(
theslice
.
stop
)
step
,
is_step_constant
=
analyze
(
theslice
.
step
)
length
,
is_length_constant
=
analyze
(
length
)
if
step
is
None
:
if
step
is
None
:
step
=
1
step
=
1
defstart
=
switch
(
lt
(
step
,
0
),
(
length
-
1
),
0
)
# First handle the easier and common case where `step` is 1 and
defstop
=
switch
(
lt
(
step
,
0
),
-
1
,
length
)
# either `start` or `stop` is a range boundary. More specializations
# could be added later. This makes the resulting graph smaller than
# in the generic case below.
if
step
==
1
:
is_start_0
=
(
start
in
[
None
,
0
]
or
(
is_start_constant
and
is_length_constant
and
start
<
0
and
start
+
length
<=
0
))
is_stop_length
=
(
stop
in
[
None
,
length
,
maxsize
]
or
(
is_stop_constant
and
is_length_constant
and
stop
>=
length
))
if
is_start_0
:
# 0:stop:1
if
is_stop_length
:
# Full slice.
return
slice
(
0
,
length
,
1
),
1
if
is_stop_constant
and
stop
>=
0
:
return
(
slice
(
0
,
switch
(
lt
(
stop
,
length
),
stop
,
length
),
1
),
1
)
stop_plus_len
=
stop
+
length
stop
=
switch
(
lt
(
stop
,
0
),
# stop < 0
switch
(
lt
(
stop_plus_len
,
0
),
# stop + len < 0
0
,
# stop + len >= 0
stop_plus_len
),
# stop >= 0: use min(stop, length)
switch
(
lt
(
stop
,
length
),
stop
,
length
))
return
slice
(
0
,
stop
,
1
),
1
elif
is_stop_length
:
# start:length:1
if
is_start_constant
and
start
>=
0
:
return
slice
(
switch
(
lt
(
start
,
length
),
start
,
length
),
length
,
1
),
1
start_plus_len
=
start
+
length
start
=
switch
(
lt
(
start
,
0
),
# start < 0
switch
(
lt
(
start_plus_len
,
0
),
# start + len < 0
0
,
# start + len >= 0
start_plus_len
),
# start >= 0: use min(start, length)
switch
(
lt
(
start
,
length
),
start
,
length
))
return
slice
(
start
,
length
,
1
),
1
# This is the generic case.
if
is_step_constant
:
# When we know the sign of `step`, the graph can be made simpler.
assert
step
!=
0
if
step
>
0
:
def
switch_neg_step
(
a
,
b
):
return
b
abs_step
=
step
sgn_step
=
1
else
:
def
switch_neg_step
(
a
,
b
):
return
a
abs_step
=
-
step
sgn_step
=
-
1
else
:
is_step_neg
=
lt
(
step
,
0
)
def
switch_neg_step
(
a
,
b
):
return
switch
(
is_step_neg
,
a
,
b
)
abs_step
=
abs
(
step
)
sgn_step
=
sgn
(
step
)
defstart
=
switch_neg_step
(
length
-
1
,
0
)
defstop
=
switch_neg_step
(
-
1
,
length
)
if
start
is
None
:
if
start
is
None
:
start
=
defstart
start
=
defstart
else
:
else
:
start
=
switch
(
lt
(
start
,
0
),
start
+
length
,
start
)
start
=
switch
(
lt
(
start
,
0
),
start
+
length
,
start
)
start
=
switch
(
lt
(
start
,
0
),
switch
(
lt
(
step
,
0
),
-
1
,
0
),
start
)
start
=
switch
(
lt
(
start
,
0
),
switch
_neg_step
(
-
1
,
0
),
start
)
start
=
switch
(
ge
(
start
,
length
),
start
=
switch
(
ge
(
start
,
length
),
switch
(
lt
(
step
,
0
),
(
length
-
1
)
,
length
),
switch
_neg_step
(
length
-
1
,
length
),
start
)
start
)
if
stop
in
[
None
,
maxsize
]:
if
stop
in
[
None
,
maxsize
]:
# The special "maxsize" case is probably not needed here,
# The special "maxsize" case is probably not needed here,
...
@@ -4282,18 +4369,20 @@ def get_canonical_form_slice(theslice, length):
...
@@ -4282,18 +4369,20 @@ def get_canonical_form_slice(theslice, length):
stop
=
switch
(
lt
(
stop
,
0
),
-
1
,
stop
)
stop
=
switch
(
lt
(
stop
,
0
),
-
1
,
stop
)
stop
=
switch
(
ge
(
stop
,
length
),
length
,
stop
)
stop
=
switch
(
ge
(
stop
,
length
),
length
,
stop
)
nw_stop
=
switch
(
lt
(
step
,
0
),
(
start
+
1
)
,
stop
)
nw_stop
=
switch
_neg_step
(
start
+
1
,
stop
)
slice_len
=
(
start
-
stop
-
1
)
//
abs
(
step
)
+
1
slice_len
=
(
start
-
stop
-
1
)
//
abs
_step
+
1
slice_len
=
switch
(
lt
(
slice_len
,
0
),
0
,
slice_len
)
slice_len
=
switch
(
lt
(
slice_len
,
0
),
0
,
slice_len
)
neg_start
=
nw_stop
-
(
slice_len
-
1
)
*
abs
(
step
)
-
1
neg_start
=
nw_stop
-
(
slice_len
-
1
)
*
abs
_step
-
1
neg_start
=
switch
(
lt
(
neg_start
,
0
),
(
nw_stop
-
1
),
neg_start
)
neg_start
=
switch
(
lt
(
neg_start
,
0
),
(
nw_stop
-
1
),
neg_start
)
nw_start
=
switch
(
lt
(
step
,
0
),
neg_start
,
start
)
nw_start
=
switch
_neg_step
(
neg_start
,
start
)
nw_start
=
switch
(
lt
(
nw_start
,
0
),
0
,
nw_start
)
nw_start
=
switch
(
lt
(
nw_start
,
0
),
0
,
nw_start
)
nw_stop
=
switch
(
lt
(
nw_stop
,
0
),
0
,
nw_stop
)
nw_stop
=
switch
(
lt
(
nw_stop
,
0
),
0
,
nw_stop
)
# Ensure start <= stop.
nw_start
=
switch
(
lt
(
nw_start
,
nw_stop
),
nw_start
,
nw_stop
)
nw_step
=
abs
(
step
)
nw_step
=
abs
_step
if
step
!=
1
:
if
step
!=
1
:
reverse
=
sgn
(
step
)
reverse
=
sgn
_step
return
slice
(
nw_start
,
nw_stop
,
nw_step
),
reverse
return
slice
(
nw_start
,
nw_stop
,
nw_step
),
reverse
else
:
else
:
return
slice
(
nw_start
,
nw_stop
,
nw_step
),
1
return
slice
(
nw_start
,
nw_stop
,
nw_step
),
1
...
@@ -4554,10 +4643,11 @@ class Subtensor(Op):
...
@@ -4554,10 +4643,11 @@ class Subtensor(Op):
and
(
idx
.
step
is
None
or
idx
.
step
==
1
)):
and
(
idx
.
step
is
None
or
idx
.
step
==
1
)):
outshp
.
append
(
xl
)
outshp
.
append
(
xl
)
else
:
else
:
cnf
=
get_canonical_form_slice
(
idx
,
xl
)
cnf
=
get_canonical_form_slice
(
idx
,
xl
)[
0
]
length
=
((
cnf
[
0
]
.
stop
-
cnf
[
0
]
.
start
-
1
)
//
cnf
[
0
]
.
step
if
cnf
.
step
==
1
:
+
1
)
length
=
cnf
.
stop
-
cnf
.
start
length
=
switch
(
lt
(
length
,
0
),
0
,
length
)
else
:
length
=
(
cnf
.
stop
-
cnf
.
start
-
1
)
//
cnf
.
step
+
1
outshp
.
append
(
length
)
outshp
.
append
(
length
)
i
+=
1
i
+=
1
else
:
else
:
...
@@ -7031,7 +7121,8 @@ class AdvancedSubtensor1(Op):
...
@@ -7031,7 +7121,8 @@ class AdvancedSubtensor1(Op):
// if all values fit.
// if all values fit.
if (!PyArray_CanCastSafely(i_type, NPY_INTP)) {
if (!PyArray_CanCastSafely(i_type, NPY_INTP)) {
npy_int64 min_val, max_val;
npy_int64 min_val, max_val;
PyObject* py_min_val = PyArray_Min(
%(i_name)
s, NPY_MAXDIMS, NULL);
PyObject* py_min_val = PyArray_Min(
%(i_name)
s, NPY_MAXDIMS,
NULL);
if (py_min_val == NULL) {
if (py_min_val == NULL) {
%(fail)
s;
%(fail)
s;
}
}
...
@@ -7040,7 +7131,8 @@ class AdvancedSubtensor1(Op):
...
@@ -7040,7 +7131,8 @@ class AdvancedSubtensor1(Op):
if (min_val == -1 && PyErr_Occurred()) {
if (min_val == -1 && PyErr_Occurred()) {
%(fail)
s;
%(fail)
s;
}
}
PyObject* py_max_val = PyArray_Max(
%(i_name)
s, NPY_MAXDIMS, NULL);
PyObject* py_max_val = PyArray_Max(
%(i_name)
s, NPY_MAXDIMS,
NULL);
if (py_max_val == NULL) {
if (py_max_val == NULL) {
%(fail)
s;
%(fail)
s;
}
}
...
@@ -7050,7 +7142,8 @@ class AdvancedSubtensor1(Op):
...
@@ -7050,7 +7142,8 @@ class AdvancedSubtensor1(Op):
%(fail)
s;
%(fail)
s;
}
}
if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {
if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {
PyErr_SetString(PyExc_IndexError, "Index contains values "
PyErr_SetString(PyExc_IndexError,
"Index contains values "
"that are bigger than the maximum array "
"that are bigger than the maximum array "
"size on this system.");
"size on this system.");
%(fail)
s;
%(fail)
s;
...
@@ -7081,7 +7174,8 @@ class AdvancedSubtensor1(Op):
...
@@ -7081,7 +7174,8 @@ class AdvancedSubtensor1(Op):
}
}
if (
%(output_name)
s != NULL) {
if (
%(output_name)
s != NULL) {
for (; i < nd; i++) {
for (; i < nd; i++) {
if (shape[i] != PyArray_DIMS(
%(a_name)
s)[i-PyArray_NDIM(indices)+1]) {
if (shape[i] != PyArray_DIMS(
%(a_name)
s)[
i-PyArray_NDIM(indices)+1]) {
Py_CLEAR(
%(output_name)
s);
Py_CLEAR(
%(output_name)
s);
break;
break;
}
}
...
@@ -7089,8 +7183,8 @@ class AdvancedSubtensor1(Op):
...
@@ -7089,8 +7183,8 @@ class AdvancedSubtensor1(Op):
}
}
}
}
}
}
%(output_name)
s = (PyArrayObject*)PyArray_TakeFrom(
%(a_name)
s, indices, 0,
%(output_name)
s = (PyArrayObject*)PyArray_TakeFrom(
%(output_name)
s, NPY_RAISE);
%(a_name)
s, indices, 0,
%(output_name)
s, NPY_RAISE);
Py_DECREF(indices);
Py_DECREF(indices);
if (
%(output_name)
s == NULL)
%(fail)
s;
if (
%(output_name)
s == NULL)
%(fail)
s;
"""
%
locals
()
"""
%
locals
()
...
@@ -7100,6 +7194,7 @@ class AdvancedSubtensor1(Op):
...
@@ -7100,6 +7194,7 @@ class AdvancedSubtensor1(Op):
advanced_subtensor1
=
AdvancedSubtensor1
()
advanced_subtensor1
=
AdvancedSubtensor1
()
class
AdvancedIncSubtensor1
(
Op
):
class
AdvancedIncSubtensor1
(
Op
):
"""Increments a subtensor using advanced slicing (list of index)"""
"""Increments a subtensor using advanced slicing (list of index)"""
def
__init__
(
self
,
inplace
=
False
,
set_instead_of_inc
=
False
):
def
__init__
(
self
,
inplace
=
False
,
set_instead_of_inc
=
False
):
...
@@ -7163,10 +7258,10 @@ class AdvancedIncSubtensor1(Op):
...
@@ -7163,10 +7258,10 @@ class AdvancedIncSubtensor1(Op):
x
[
idx
]
=
y
x
[
idx
]
=
y
else
:
else
:
increment
=
inplace_increment
increment
=
inplace_increment
if
increment
is
None
:
if
increment
is
None
:
increment
=
self
.
inplace_increment1d_slow
increment
=
self
.
inplace_increment1d_slow
increment
(
x
,
idx
,
y
)
increment
(
x
,
idx
,
y
)
out
[
0
]
=
x
out
[
0
]
=
x
...
@@ -7209,7 +7304,8 @@ class AdvancedIncSubtensor1(Op):
...
@@ -7209,7 +7304,8 @@ class AdvancedIncSubtensor1(Op):
return
[
gx
,
gy
]
+
[
DisconnectedType
()()]
*
len
(
idx_list
)
return
[
gx
,
gy
]
+
[
DisconnectedType
()()]
*
len
(
idx_list
)
advanced_inc_subtensor1
=
AdvancedIncSubtensor1
()
advanced_inc_subtensor1
=
AdvancedIncSubtensor1
()
def
as_index_variable
(
idx
):
def
as_index_variable
(
idx
):
if
idx
is
None
:
if
idx
is
None
:
return
NoneConst
return
NoneConst
...
@@ -7269,6 +7365,7 @@ class SliceType(gof.Type):
...
@@ -7269,6 +7365,7 @@ class SliceType(gof.Type):
slicetype
=
SliceType
()
slicetype
=
SliceType
()
class
NoneTypeT
(
gof
.
Type
):
class
NoneTypeT
(
gof
.
Type
):
def
filter
(
self
,
x
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
x
,
strict
=
False
,
allow_downcast
=
None
):
...
...
theano/tensor/tests/test_basic.py
浏览文件 @
0e35cc21
...
@@ -6632,6 +6632,18 @@ class T_get_scalar_constant_value(unittest.TestCase):
...
@@ -6632,6 +6632,18 @@ class T_get_scalar_constant_value(unittest.TestCase):
for
j
in
range
(
c
.
value
.
shape
[
1
]):
for
j
in
range
(
c
.
value
.
shape
[
1
]):
assert
get_scalar_constant_value
(
c
[
i
,
j
])
==
c
.
value
[
i
,
j
]
assert
get_scalar_constant_value
(
c
[
i
,
j
])
==
c
.
value
[
i
,
j
]
def
test_numpy_array
(
self
):
# Regression test for crash when called on a numpy array.
assert
get_scalar_constant_value
(
numpy
.
array
(
3
))
==
3
self
.
assertRaises
(
tensor
.
NotScalarConstantError
,
get_scalar_constant_value
,
numpy
.
array
([
0
,
1
]))
self
.
assertRaises
(
tensor
.
EmptyConstantError
,
get_scalar_constant_value
,
numpy
.
array
([]))
class
T_as_tensor_variable
(
unittest
.
TestCase
):
class
T_as_tensor_variable
(
unittest
.
TestCase
):
"""
"""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论