Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a05ffe0f
提交
a05ffe0f
authored
7月 26, 2013
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Interface change: tensor.basic do not contain *Subtensor* obj.
Also move take to subtensor.py This allow subtensor.py to depend on basic.py. This is a more sensible dependency.
上级
5ab56080
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
39 行增加
和
73 行删除
+39
-73
__init__.py
theano/tensor/__init__.py
+1
-0
basic.py
theano/tensor/basic.py
+11
-48
nnet.py
theano/tensor/nnet/nnet.py
+9
-8
opt.py
theano/tensor/opt.py
+18
-17
没有找到文件。
theano/tensor/__init__.py
浏览文件 @
a05ffe0f
...
@@ -4,6 +4,7 @@ __docformat__ = "restructuredtext en"
...
@@ -4,6 +4,7 @@ __docformat__ = "restructuredtext en"
import
warnings
import
warnings
from
theano.tensor.basic
import
*
from
theano.tensor.basic
import
*
from
theano.tensor.subtensor
import
*
from
theano.tensor.type_other
import
*
from
theano.tensor.type_other
import
*
from
theano.tensor
import
opt
from
theano.tensor
import
opt
...
...
theano/tensor/basic.py
浏览文件 @
a05ffe0f
...
@@ -18,12 +18,6 @@ from theano.gof import Apply, Constant, Op, Variable
...
@@ -18,12 +18,6 @@ from theano.gof import Apply, Constant, Op, Variable
from
theano.tensor
import
elemwise
from
theano.tensor
import
elemwise
from
theano.tensor.type
import
TensorType
from
theano.tensor.type
import
TensorType
from
theano.tensor.subtensor
import
(
AdvancedIndexingError
,
Subtensor
,
IncSubtensor
,
inc_subtensor
,
set_subtensor
,
AdvancedSubtensor1
,
AdvancedIncSubtensor1
,
AdvancedSubtensor
,
AdvancedIncSubtensor
,
advanced_subtensor1
)
from
theano
import
scalar
as
scal
from
theano
import
scalar
as
scal
from
theano.gof.python25
import
partial
,
any
,
all
,
maxsize
from
theano.gof.python25
import
partial
,
any
,
all
,
maxsize
from
theano.gof.utils
import
hashtype
,
MethodNotDefined
from
theano.gof.utils
import
hashtype
,
MethodNotDefined
...
@@ -573,7 +567,7 @@ def get_scalar_constant_value(v):
...
@@ -573,7 +567,7 @@ def get_scalar_constant_value(v):
ret
=
[[
None
]]
ret
=
[[
None
]]
v
.
owner
.
op
.
perform
(
v
.
owner
,
[
const
],
ret
)
v
.
owner
.
op
.
perform
(
v
.
owner
,
[
const
],
ret
)
return
ret
[
0
][
0
]
return
ret
[
0
][
0
]
if
isinstance
(
v
.
owner
.
op
,
Subtensor
)
and
v
.
ndim
==
0
:
if
isinstance
(
v
.
owner
.
op
,
theano
.
tensor
.
subtensor
.
Subtensor
)
and
v
.
ndim
==
0
:
# This condition depends on Subtensor always embedding constant
# This condition depends on Subtensor always embedding constant
# indices in the Op rather than making them inputs to the Apply
# indices in the Op rather than making them inputs to the Apply
# node.
# node.
...
@@ -1199,8 +1193,8 @@ class _tensor_py_operators:
...
@@ -1199,8 +1193,8 @@ class _tensor_py_operators:
axis
=
None
axis
=
None
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
try
:
try
:
arg
==
numpy
.
newaxis
or
Subtensor
.
convert
(
arg
)
arg
==
numpy
.
newaxis
or
theano
.
tensor
.
subtensor
.
Subtensor
.
convert
(
arg
)
except
AdvancedIndexingError
:
except
theano
.
tensor
.
subtensor
.
AdvancedIndexingError
:
if
advanced
:
if
advanced
:
axis
=
None
axis
=
None
break
break
...
@@ -1220,7 +1214,7 @@ class _tensor_py_operators:
...
@@ -1220,7 +1214,7 @@ class _tensor_py_operators:
theano
.
tensor
.
sharedvar
.
TensorSharedVariable
))):
theano
.
tensor
.
sharedvar
.
TensorSharedVariable
))):
return
self
.
take
(
arg
,
axis
)
return
self
.
take
(
arg
,
axis
)
else
:
else
:
return
AdvancedSubtensor
()(
self
,
*
args
)
return
theano
.
tensor
.
subtensor
.
AdvancedSubtensor
()(
self
,
*
args
)
else
:
else
:
if
numpy
.
newaxis
in
args
:
if
numpy
.
newaxis
in
args
:
# None (aka np.newaxis) in numpy indexing means to add a
# None (aka np.newaxis) in numpy indexing means to add a
...
@@ -1244,11 +1238,12 @@ class _tensor_py_operators:
...
@@ -1244,11 +1238,12 @@ class _tensor_py_operators:
rval
=
view
.
__getitem__
(
tuple
(
new_args
))
rval
=
view
.
__getitem__
(
tuple
(
new_args
))
return
rval
return
rval
else
:
else
:
return
Subtensor
(
args
)(
self
,
*
Subtensor
.
collapse
(
args
,
return
theano
.
tensor
.
subtensor
.
Subtensor
(
args
)(
self
,
*
theano
.
tensor
.
subtensor
.
Subtensor
.
collapse
(
args
,
lambda
entry
:
isinstance
(
entry
,
Variable
)))
lambda
entry
:
isinstance
(
entry
,
Variable
)))
def
take
(
self
,
indices
,
axis
=
None
,
mode
=
'raise'
):
def
take
(
self
,
indices
,
axis
=
None
,
mode
=
'raise'
):
return
take
(
self
,
indices
,
axis
,
mode
)
return
t
heano
.
tensor
.
subtensor
.
t
ake
(
self
,
indices
,
axis
,
mode
)
# COPYING
# COPYING
def
copy
(
self
):
def
copy
(
self
):
...
@@ -3251,9 +3246,9 @@ class Alloc(gof.Op):
...
@@ -3251,9 +3246,9 @@ class Alloc(gof.Op):
return
False
return
False
elif
(
not
isinstance
(
client
[
0
],
basestring
)
elif
(
not
isinstance
(
client
[
0
],
basestring
)
and
isinstance
(
client
[
0
]
.
op
,
(
and
isinstance
(
client
[
0
]
.
op
,
(
IncSubtensor
,
theano
.
tensor
.
subtensor
.
IncSubtensor
,
AdvancedIncSubtensor1
,
theano
.
tensor
.
subtensor
.
AdvancedIncSubtensor1
,
AdvancedIncSubtensor
,
theano
.
tensor
.
subtensor
.
AdvancedIncSubtensor
,
))):
))):
return
False
return
False
return
True
return
True
...
@@ -3828,7 +3823,7 @@ class Split(Op):
...
@@ -3828,7 +3823,7 @@ class Split(Op):
out_shapes
=
[]
out_shapes
=
[]
for
i
in
range
(
self
.
len_splits
):
for
i
in
range
(
self
.
len_splits
):
temp
=
as_tensor_variable
(
shp_x
)
temp
=
as_tensor_variable
(
shp_x
)
temp
=
set_subtensor
(
temp
[
axis
],
splits
[
i
])
temp
=
theano
.
tensor
.
subtensor
.
set_subtensor
(
temp
[
axis
],
splits
[
i
])
temp
=
[
temp
[
i
]
for
i
in
range
(
len
(
shp_x
))]
temp
=
[
temp
[
i
]
for
i
in
range
(
len
(
shp_x
))]
out_shapes
.
append
(
temp
)
out_shapes
.
append
(
temp
)
return
out_shapes
return
out_shapes
...
@@ -5085,38 +5080,6 @@ def inverse_permutation(perm):
...
@@ -5085,38 +5080,6 @@ def inverse_permutation(perm):
inverse
=
True
)
inverse
=
True
)
def
take
(
a
,
indices
,
axis
=
None
,
mode
=
'raise'
):
a
=
as_tensor_variable
(
a
)
indices
=
as_tensor_variable
(
indices
)
# Reuse advanced_subtensor1 if indices is a vector
if
indices
.
ndim
==
1
:
if
mode
==
'clip'
:
indices
=
clip
(
indices
,
0
,
a
.
shape
[
axis
]
-
1
)
elif
mode
==
'wrap'
:
indices
=
indices
%
a
.
shape
[
axis
]
if
axis
is
None
:
return
advanced_subtensor1
(
a
.
flatten
(),
indices
)
elif
axis
==
0
:
return
advanced_subtensor1
(
a
,
indices
)
else
:
if
axis
<
0
:
axis
+=
a
.
ndim
assert
axis
>=
0
shuffle
=
range
(
a
.
ndim
)
shuffle
[
0
]
=
axis
shuffle
[
axis
]
=
0
return
advanced_subtensor1
(
a
.
dimshuffle
(
shuffle
),
indices
)
.
dimshuffle
(
shuffle
)
if
axis
is
None
:
shape
=
indices
.
shape
ndim
=
indices
.
ndim
else
:
shape
=
concatenate
(
[
a
.
shape
[:
axis
],
indices
.
shape
,
a
.
shape
[
axis
+
1
:]])
ndim
=
a
.
ndim
+
indices
.
ndim
-
1
return
take
(
a
,
indices
.
flatten
(),
axis
,
mode
)
.
reshape
(
shape
,
ndim
)
#########################
#########################
# Linalg : Dot
# Linalg : Dot
#########################
#########################
...
...
theano/tensor/nnet/nnet.py
浏览文件 @
a05ffe0f
...
@@ -8,6 +8,7 @@ import numpy
...
@@ -8,6 +8,7 @@ import numpy
import
theano
import
theano
from
theano
import
gof
from
theano
import
gof
from
theano.tensor
import
basic
as
tensor
from
theano.tensor
import
basic
as
tensor
from
theano.tensor
import
subtensor
from
theano.tensor
import
elemwise
,
dmatrix
,
fmatrix
,
dvector
,
fvector
from
theano.tensor
import
elemwise
,
dmatrix
,
fmatrix
,
dvector
,
fvector
from
theano.tensor
import
opt
from
theano.tensor
import
opt
from
theano.compile
import
optdb
from
theano.compile
import
optdb
...
@@ -1004,7 +1005,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
...
@@ -1004,7 +1005,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
# typically we should not need the gradient w.r.t. dy).
# typically we should not need the gradient w.r.t. dy).
y_idx_range
=
tensor
.
arange
(
y_idx
.
shape
[
0
])
y_idx_range
=
tensor
.
arange
(
y_idx
.
shape
[
0
])
g_dy
=
tensor
.
sum
(
g_dy
=
tensor
.
sum
(
g_dx
*
tensor
.
AdvancedIncSubtensor
()(
g_dx
*
sub
tensor
.
AdvancedIncSubtensor
()(
sm
,
tensor
.
fill
(
dy
,
-
1
),
y_idx_range
,
y_idx
),
sm
,
tensor
.
fill
(
dy
,
-
1
),
y_idx_range
,
y_idx
),
axis
=
1
)
axis
=
1
)
g_sm
=
dy
.
dimshuffle
(
0
,
'x'
)
*
g_dx
g_sm
=
dy
.
dimshuffle
(
0
,
'x'
)
*
g_dx
...
@@ -1396,7 +1397,7 @@ def _check_rows_is_arange_len_labels(rows, labels):
...
@@ -1396,7 +1397,7 @@ def _check_rows_is_arange_len_labels(rows, labels):
# Not sure if that case happens any more after the introduction of
# Not sure if that case happens any more after the introduction of
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if
isinstance
(
stop
.
owner
.
op
,
tensor
.
Subtensor
):
if
isinstance
(
stop
.
owner
.
op
,
sub
tensor
.
Subtensor
):
shape_subtensor
=
stop
.
owner
shape_subtensor
=
stop
.
owner
if
list
(
shape_subtensor
.
op
.
idx_list
)
==
[
0
]:
if
list
(
shape_subtensor
.
op
.
idx_list
)
==
[
0
]:
shape_var
,
=
shape_subtensor
.
inputs
shape_var
,
=
shape_subtensor
.
inputs
...
@@ -1424,7 +1425,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
...
@@ -1424,7 +1425,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
log
=
None
log
=
None
sm
=
None
sm
=
None
# First case: log(softmax(x))[rows, labels]
# First case: log(softmax(x))[rows, labels]
if
isinstance
(
node
.
op
,
tensor
.
AdvancedSubtensor
):
if
isinstance
(
node
.
op
,
sub
tensor
.
AdvancedSubtensor
):
try
:
try
:
log
,
rows
,
labels
=
node
.
inputs
log
,
rows
,
labels
=
node
.
inputs
except
Exception
:
except
Exception
:
...
@@ -1435,7 +1436,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
...
@@ -1435,7 +1436,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
# Second case: log(softmax(x)[rows, labels])
# Second case: log(softmax(x)[rows, labels])
if
node
.
op
==
tensor
.
log
:
if
node
.
op
==
tensor
.
log
:
pre_log
=
node
.
inputs
[
0
]
.
owner
pre_log
=
node
.
inputs
[
0
]
.
owner
if
pre_log
and
isinstance
(
pre_log
.
op
,
tensor
.
AdvancedSubtensor
):
if
pre_log
and
isinstance
(
pre_log
.
op
,
sub
tensor
.
AdvancedSubtensor
):
try
:
try
:
sm
,
rows
,
labels
=
pre_log
.
inputs
sm
,
rows
,
labels
=
pre_log
.
inputs
except
Exception
:
except
Exception
:
...
@@ -1524,7 +1525,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1524,7 +1525,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# After the check for AdvancedIncSubtensor, if anything does not fit with
# After the check for AdvancedIncSubtensor, if anything does not fit with
# the formula above, there's no way to fit it with the the second case,
# the formula above, there's no way to fit it with the the second case,
# so we return immediately.
# so we return immediately.
if
d_sm
.
owner
and
isinstance
(
d_sm
.
owner
.
op
,
tensor
.
AdvancedIncSubtensor
):
if
d_sm
.
owner
and
isinstance
(
d_sm
.
owner
.
op
,
sub
tensor
.
AdvancedIncSubtensor
):
try
:
try
:
z
,
incr
,
rows
,
labels
=
d_sm
.
owner
.
inputs
z
,
incr
,
rows
,
labels
=
d_sm
.
owner
.
inputs
except
Exception
:
except
Exception
:
...
@@ -1566,7 +1567,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1566,7 +1567,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if
not
denom
.
owner
:
if
not
denom
.
owner
:
return
return
if
isinstance
(
denom
.
owner
.
op
,
tensor
.
AdvancedSubtensor
):
if
isinstance
(
denom
.
owner
.
op
,
sub
tensor
.
AdvancedSubtensor
):
# Base case
# Base case
adv_subtensor
=
denom
adv_subtensor
=
denom
#out_grad /= 1.
#out_grad /= 1.
...
@@ -1575,7 +1576,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1575,7 +1576,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# and the output gradient
# and the output gradient
for
i
,
input
in
enumerate
(
denom
.
owner
.
inputs
):
for
i
,
input
in
enumerate
(
denom
.
owner
.
inputs
):
if
input
.
owner
and
isinstance
(
input
.
owner
.
op
,
if
input
.
owner
and
isinstance
(
input
.
owner
.
op
,
tensor
.
AdvancedSubtensor
):
sub
tensor
.
AdvancedSubtensor
):
other_inputs
=
[
in_
for
(
j
,
other_inputs
=
[
in_
for
(
j
,
in_
)
in
enumerate
(
denom
.
owner
.
inputs
)
if
j
!=
i
]
in_
)
in
enumerate
(
denom
.
owner
.
inputs
)
if
j
!=
i
]
if
len
(
other_inputs
)
==
1
:
if
len
(
other_inputs
)
==
1
:
...
@@ -1630,7 +1631,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
...
@@ -1630,7 +1631,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
return
return
# Check the numerator (AdvancedIncSubtensor)
# Check the numerator (AdvancedIncSubtensor)
if
num
.
owner
and
isinstance
(
num
.
owner
.
op
,
tensor
.
AdvancedIncSubtensor
):
if
num
.
owner
and
isinstance
(
num
.
owner
.
op
,
sub
tensor
.
AdvancedIncSubtensor
):
try
:
try
:
z
,
incr
,
rows
,
labels
=
num
.
owner
.
inputs
z
,
incr
,
rows
,
labels
=
num
.
owner
.
inputs
except
Exception
:
except
Exception
:
...
...
theano/tensor/opt.py
浏览文件 @
a05ffe0f
...
@@ -24,7 +24,8 @@ from theano.gof.python25 import maxsize
...
@@ -24,7 +24,8 @@ from theano.gof.python25 import maxsize
from
theano.gof.utils
import
MethodNotDefined
from
theano.gof.utils
import
MethodNotDefined
from
theano.configparser
import
config
from
theano.configparser
import
config
from
theano.tensor.elemwise
import
Elemwise
,
DimShuffle
from
theano.tensor.elemwise
import
Elemwise
,
DimShuffle
from
theano.tensor.subtensor
import
get_idx_list
,
get_canonical_form_slice
from
theano.tensor.subtensor
import
(
get_idx_list
,
get_canonical_form_slice
,
Subtensor
,
IncSubtensor
,
AdvancedIncSubtensor1
)
from
theano
import
scalar
from
theano
import
scalar
from
theano.tensor
import
basic
as
T
from
theano.tensor
import
basic
as
T
from
theano
import
compile
# to register the optimizer built by this file
from
theano
import
compile
# to register the optimizer built by this file
...
@@ -1218,13 +1219,13 @@ def local_track_shape_i(node):
...
@@ -1218,13 +1219,13 @@ def local_track_shape_i(node):
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@gof.local_optimizer
([
T
.
Subtensor
])
@gof.local_optimizer
([
Subtensor
])
def
local_subtensor_make_vector
(
node
):
def
local_subtensor_make_vector
(
node
):
# replace all subtensor(make_vector) like:
# replace all subtensor(make_vector) like:
# [a,b,c][0] -> a
# [a,b,c][0] -> a
# [a,b,c][0:2] -> [a,b]
# [a,b,c][0:2] -> [a,b]
# we can do this for constant indexes
# we can do this for constant indexes
if
isinstance
(
node
.
op
,
T
.
Subtensor
):
if
isinstance
(
node
.
op
,
Subtensor
):
# This optimization needs ShapeOpt and fgraph.shape_feature
# This optimization needs ShapeOpt and fgraph.shape_feature
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
if
x
.
owner
and
x
.
owner
.
op
==
make_vector
:
if
x
.
owner
and
x
.
owner
.
op
==
make_vector
:
...
@@ -1592,12 +1593,12 @@ def local_upcast_elemwise_constant_inputs(node):
...
@@ -1592,12 +1593,12 @@ def local_upcast_elemwise_constant_inputs(node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@gof.local_optimizer
([
T
.
Subtensor
])
@gof.local_optimizer
([
Subtensor
])
def
local_useless_subtensor
(
node
):
def
local_useless_subtensor
(
node
):
"""
"""
Remove Subtensor if it takes the full input
Remove Subtensor if it takes the full input
"""
"""
if
isinstance
(
node
.
op
,
T
.
Subtensor
):
if
isinstance
(
node
.
op
,
Subtensor
):
# This optimization needs ShapeOpt and fgraph.shape_feature
# This optimization needs ShapeOpt and fgraph.shape_feature
if
not
hasattr
(
node
.
fgraph
,
'shape_feature'
):
if
not
hasattr
(
node
.
fgraph
,
'shape_feature'
):
return
return
...
@@ -1678,7 +1679,7 @@ def local_subtensor_lift(node):
...
@@ -1678,7 +1679,7 @@ def local_subtensor_lift(node):
when x,... are broadcasted scalar or not broadcasted at all
when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => rebroadcast(x[idx])
rebroadcast(x)[idx] => rebroadcast(x[idx])
"""
"""
if
isinstance
(
node
.
op
,
T
.
Subtensor
):
if
isinstance
(
node
.
op
,
Subtensor
):
u
=
node
.
inputs
[
0
]
u
=
node
.
inputs
[
0
]
if
not
u
.
owner
or
len
(
u
.
clients
)
>
1
:
if
not
u
.
owner
or
len
(
u
.
clients
)
>
1
:
return
False
return
False
...
@@ -1737,7 +1738,7 @@ def local_subtensor_lift(node):
...
@@ -1737,7 +1738,7 @@ def local_subtensor_lift(node):
new_axis
+=
[(
j
,
u
.
broadcastable
[
i
])]
new_axis
+=
[(
j
,
u
.
broadcastable
[
i
])]
j
+=
1
j
+=
1
subt_x
=
T
.
Subtensor
(
node
.
op
.
idx_list
)(
u
.
owner
.
inputs
[
0
])
subt_x
=
Subtensor
(
node
.
op
.
idx_list
)(
u
.
owner
.
inputs
[
0
])
rbcast_subt_x
=
T
.
Rebroadcast
(
*
new_axis
)(
subt_x
)
rbcast_subt_x
=
T
.
Rebroadcast
(
*
new_axis
)(
subt_x
)
return
[
rbcast_subt_x
]
return
[
rbcast_subt_x
]
...
@@ -1886,9 +1887,9 @@ def local_subtensor_merge(node):
...
@@ -1886,9 +1887,9 @@ def local_subtensor_merge(node):
expresses all slices in a canonical form, and then merges them together.
expresses all slices in a canonical form, and then merges them together.
"""
"""
if
isinstance
(
node
.
op
,
T
.
Subtensor
):
if
isinstance
(
node
.
op
,
Subtensor
):
u
=
node
.
inputs
[
0
]
u
=
node
.
inputs
[
0
]
if
u
.
owner
and
isinstance
(
u
.
owner
.
op
,
T
.
Subtensor
):
if
u
.
owner
and
isinstance
(
u
.
owner
.
op
,
Subtensor
):
# We can merge :)
# We can merge :)
# x actual tensor on which we are picking slices
# x actual tensor on which we are picking slices
x
=
u
.
owner
.
inputs
[
0
]
x
=
u
.
owner
.
inputs
[
0
]
...
@@ -1928,8 +1929,8 @@ def local_subtensor_merge(node):
...
@@ -1928,8 +1929,8 @@ def local_subtensor_merge(node):
else
:
else
:
merged_slices
+=
slices1
[
pos_1
:]
merged_slices
+=
slices1
[
pos_1
:]
subtens
=
T
.
Subtensor
(
merged_slices
)
subtens
=
Subtensor
(
merged_slices
)
sl_ins
=
T
.
Subtensor
.
collapse
(
sl_ins
=
Subtensor
.
collapse
(
merged_slices
,
merged_slices
,
lambda
x
:
isinstance
(
x
,
T
.
Variable
))
lambda
x
:
isinstance
(
x
,
T
.
Variable
))
out
=
subtens
.
make_node
(
x
,
*
sl_ins
)
.
outputs
[
0
]
out
=
subtens
.
make_node
(
x
,
*
sl_ins
)
.
outputs
[
0
]
...
@@ -1942,7 +1943,7 @@ def local_subtensor_merge(node):
...
@@ -1942,7 +1943,7 @@ def local_subtensor_merge(node):
@gof.local_optimizer
([])
@gof.local_optimizer
([])
def
local_subtensor_of_alloc
(
node
):
def
local_subtensor_of_alloc
(
node
):
"""alloc[x:y] -> alloc"""
"""alloc[x:y] -> alloc"""
if
not
isinstance
(
node
.
op
,
T
.
Subtensor
):
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
False
return
False
u
=
node
.
inputs
[
0
]
u
=
node
.
inputs
[
0
]
if
u
.
owner
is
None
:
if
u
.
owner
is
None
:
...
@@ -2027,7 +2028,7 @@ def local_IncSubtensor_serialize(node):
...
@@ -2027,7 +2028,7 @@ def local_IncSubtensor_serialize(node):
def
movable
(
i
):
def
movable
(
i
):
# Return True iff this is a incsubtensor that we can move
# Return True iff this is a incsubtensor that we can move
return
i
.
owner
\
return
i
.
owner
\
and
isinstance
(
i
.
owner
.
op
,
T
.
IncSubtensor
)
\
and
isinstance
(
i
.
owner
.
op
,
IncSubtensor
)
\
and
i
.
type
==
o_type
\
and
i
.
type
==
o_type
\
and
len
(
i
.
clients
)
==
1
\
and
len
(
i
.
clients
)
==
1
\
and
not
i
.
owner
.
op
.
set_instead_of_inc
and
not
i
.
owner
.
op
.
set_instead_of_inc
...
@@ -2061,7 +2062,7 @@ def local_inplace_setsubtensor(node):
...
@@ -2061,7 +2062,7 @@ def local_inplace_setsubtensor(node):
"""
"""
Also work for GpuIncSubtensor
Also work for GpuIncSubtensor
"""
"""
if
isinstance
(
node
.
op
,
T
.
IncSubtensor
)
and
not
node
.
op
.
inplace
:
if
isinstance
(
node
.
op
,
IncSubtensor
)
and
not
node
.
op
.
inplace
:
new_op
=
node
.
op
.
__class__
(
new_op
=
node
.
op
.
__class__
(
node
.
op
.
idx_list
,
inplace
=
True
,
node
.
op
.
idx_list
,
inplace
=
True
,
set_instead_of_inc
=
node
.
op
.
set_instead_of_inc
,
set_instead_of_inc
=
node
.
op
.
set_instead_of_inc
,
...
@@ -2078,7 +2079,7 @@ compile.optdb.register('inplace_setsubtensor',
...
@@ -2078,7 +2079,7 @@ compile.optdb.register('inplace_setsubtensor',
@gof.local_optimizer
([
None
])
@gof.local_optimizer
([
None
])
def
local_inplace_incsubtensor1
(
node
):
def
local_inplace_incsubtensor1
(
node
):
""" also work for GpuAdvancedIncSubtensor1 """
""" also work for GpuAdvancedIncSubtensor1 """
if
isinstance
(
node
.
op
,
T
.
AdvancedIncSubtensor1
)
and
not
node
.
op
.
inplace
:
if
isinstance
(
node
.
op
,
AdvancedIncSubtensor1
)
and
not
node
.
op
.
inplace
:
new_op
=
node
.
op
.
__class__
(
new_op
=
node
.
op
.
__class__
(
inplace
=
True
,
set_instead_of_inc
=
node
.
op
.
set_instead_of_inc
)
inplace
=
True
,
set_instead_of_inc
=
node
.
op
.
set_instead_of_inc
)
new_node
=
new_op
(
*
node
.
inputs
)
new_node
=
new_op
(
*
node
.
inputs
)
...
@@ -2098,7 +2099,7 @@ def local_incsubtensor_of_allocs(node):
...
@@ -2098,7 +2099,7 @@ def local_incsubtensor_of_allocs(node):
"""
"""
IncSubtensor(x, zeros, idx) -> x
IncSubtensor(x, zeros, idx) -> x
"""
"""
if
isinstance
(
node
.
op
,
T
.
IncSubtensor
)
and
not
node
.
op
.
set_instead_of_inc
:
if
isinstance
(
node
.
op
,
IncSubtensor
)
and
not
node
.
op
.
set_instead_of_inc
:
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
y
=
node
.
inputs
[
1
]
y
=
node
.
inputs
[
1
]
replace
=
False
replace
=
False
...
@@ -2123,7 +2124,7 @@ def local_setsubtensor_of_allocs(node):
...
@@ -2123,7 +2124,7 @@ def local_setsubtensor_of_allocs(node):
when x is constant or alloc.
when x is constant or alloc.
"""
"""
if
isinstance
(
node
.
op
,
T
.
IncSubtensor
)
and
node
.
op
.
set_instead_of_inc
:
if
isinstance
(
node
.
op
,
IncSubtensor
)
and
node
.
op
.
set_instead_of_inc
:
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
y
=
node
.
inputs
[
1
]
y
=
node
.
inputs
[
1
]
replace_x
=
None
replace_x
=
None
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论