提交 00ed78d2 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2511 from sisp/get_idx_list

Inconsistent function call of imported get_idx_list
...@@ -830,8 +830,7 @@ class ShapeFeature(object): ...@@ -830,8 +830,7 @@ class ShapeFeature(object):
# The current Subtensor always put constant index in the graph. # The current Subtensor always put constant index in the graph.
# This was not True in the past. So call the Subtensor function # This was not True in the past. So call the Subtensor function
# that will return the right index. # that will return the right index.
idx = theano.tensor.subtensor.get_idx_list(s_i.owner.inputs, idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list)
s_i.owner.op.idx_list)
assert len(idx) == 1 assert len(idx) == 1
idx = idx[0] idx = idx[0]
try: try:
...@@ -1865,8 +1864,7 @@ def local_useless_inc_subtensor(node): ...@@ -1865,8 +1864,7 @@ def local_useless_inc_subtensor(node):
# Check that we keep all the original data. # Check that we keep all the original data.
# Put the constant inputs in the slice. # Put the constant inputs in the slice.
idx_cst = theano.tensor.subtensor.get_idx_list(node.inputs[1:], idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list)
node.op.idx_list)
if all(isinstance(e, slice) and e.start is None and if all(isinstance(e, slice) and e.start is None and
e.stop is None and (e.step is None or T.extract_constant(e.step) == -1) e.stop is None and (e.step is None or T.extract_constant(e.step) == -1)
for e in idx_cst): for e in idx_cst):
...@@ -2358,7 +2356,7 @@ def local_subtensor_of_dot(node): ...@@ -2358,7 +2356,7 @@ def local_subtensor_of_dot(node):
a = node.inputs[0].owner.inputs[0] a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1] b = node.inputs[0].owner.inputs[1]
idx_list = theano.tensor.subtensor.get_idx_list(node.inputs, node.op.idx_list) idx_list = get_idx_list(node.inputs, node.op.idx_list)
num_a_indices = min(a.ndim - 1, len(idx_list)) num_a_indices = min(a.ndim - 1, len(idx_list))
a_indices = idx_list[:num_a_indices] a_indices = idx_list[:num_a_indices]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论