提交 428e4f8c authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

use imported function get_idx_list consistently instead of sometimes…

use imported function get_idx_list consistently instead of sometimes theano.tensor.subtensor.get_idx_list
上级 4d8536b5
......@@ -830,8 +830,7 @@ class ShapeFeature(object):
# The current Subtensor always put constant index in the graph.
# This was not True in the past. So call the Subtensor function
# that will return the right index.
idx = theano.tensor.subtensor.get_idx_list(s_i.owner.inputs,
s_i.owner.op.idx_list)
idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list)
assert len(idx) == 1
idx = idx[0]
try:
......@@ -1865,8 +1864,7 @@ def local_useless_inc_subtensor(node):
# Check that we keep all the original data.
# Put the constant inputs in the slice.
idx_cst = theano.tensor.subtensor.get_idx_list(node.inputs[1:],
node.op.idx_list)
idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list)
if all(isinstance(e, slice) and e.start is None and
e.stop is None and (e.step is None or T.extract_constant(e.step) == -1)
for e in idx_cst):
......@@ -2358,7 +2356,7 @@ def local_subtensor_of_dot(node):
a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1]
idx_list = theano.tensor.subtensor.get_idx_list(node.inputs, node.op.idx_list)
idx_list = get_idx_list(node.inputs, node.op.idx_list)
num_a_indices = min(a.ndim - 1, len(idx_list))
a_indices = idx_list[:num_a_indices]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论