Commit a37ccad2, authored by Frederic

Fix imports after the move of *subtensor* out of `theano.tensor.basic` (e.g. `tensor.basic.Subtensor` is now `tensor.Subtensor`).

Parent commit: a05ffe0f
...@@ -437,8 +437,8 @@ acceptable_ops = (theano.tensor.basic.Dot, ...@@ -437,8 +437,8 @@ acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Shape, theano.tensor.basic.Shape,
theano.tensor.basic.SpecifyShape, theano.tensor.basic.SpecifyShape,
theano.tensor.basic.MaxAndArgmax, theano.tensor.basic.MaxAndArgmax,
theano.tensor.basic.Subtensor, theano.tensor.Subtensor,
theano.tensor.basic.IncSubtensor, theano.tensor.IncSubtensor,
theano.tensor.basic.Rebroadcast, theano.tensor.basic.Rebroadcast,
theano.tensor.basic.Alloc, theano.tensor.basic.Alloc,
theano.tensor.elemwise.Elemwise, theano.tensor.elemwise.Elemwise,
......
...@@ -691,7 +691,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -691,7 +691,7 @@ class ScanSaveMem(gof.Optimizer):
break break
# 2.2 non-subtensor nodes # 2.2 non-subtensor nodes
#=> output needs all its intermediate values #=> output needs all its intermediate values
elif not isinstance(cl.op, tensor.basic.Subtensor): elif not isinstance(cl.op, tensor.Subtensor):
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
...@@ -699,7 +699,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -699,7 +699,7 @@ class ScanSaveMem(gof.Optimizer):
#=> output might need to store just a subset of its values #=> output might need to store just a subset of its values
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = tensor.basic.get_idx_list(cl.inputs, this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list) cl.op.idx_list)
if this_slice is None: if this_slice is None:
# if unable to extract idx_list # if unable to extract idx_list
...@@ -719,7 +719,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -719,7 +719,7 @@ class ScanSaveMem(gof.Optimizer):
length = shape_of[out][0] length = shape_of[out][0]
except KeyError: except KeyError:
length = out.shape[0] length = out.shape[0]
cf_slice = tensor.basic.get_canonical_form_slice( cf_slice = tensor.get_canonical_form_slice(
this_slice[0], length) this_slice[0], length)
slices[i] += [(cf_slice, this_slice)] slices[i] += [(cf_slice, this_slice)]
...@@ -795,11 +795,11 @@ class ScanSaveMem(gof.Optimizer): ...@@ -795,11 +795,11 @@ class ScanSaveMem(gof.Optimizer):
if type(cl) == str: if type(cl) == str:
store_steps[i] = 0 store_steps[i] = 0
break break
elif not isinstance(cl.op, tensor.basic.Subtensor): elif not isinstance(cl.op, tensor.Subtensor):
store_steps[i] = 0 store_steps[i] = 0
break break
else: else:
this_slice = tensor.basic.get_idx_list(cl.inputs, this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list) cl.op.idx_list)
if this_slice is None: if this_slice is None:
store_steps[i] = 0 store_steps[i] = 0
...@@ -817,7 +817,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -817,7 +817,7 @@ class ScanSaveMem(gof.Optimizer):
length = shape_of[out][0] length = shape_of[out][0]
except KeyError: except KeyError:
length = out.shape[0] length = out.shape[0]
cf_slice = tensor.basic.get_canonical_form_slice( cf_slice = tensor.get_canonical_form_slice(
this_slice[0], length) this_slice[0], length)
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
...@@ -973,9 +973,9 @@ class ScanSaveMem(gof.Optimizer): ...@@ -973,9 +973,9 @@ class ScanSaveMem(gof.Optimizer):
nw_slice = (fslice,) + tuple(old_slices[1:]) nw_slice = (fslice,) + tuple(old_slices[1:])
nw_pos = inv_compress_map[idx] nw_pos = inv_compress_map[idx]
subtens = tensor.basic.Subtensor(nw_slice) subtens = tensor.Subtensor(nw_slice)
# slice inputs # slice inputs
sl_ins = tensor.basic.Subtensor.collapse( sl_ins = tensor.Subtensor.collapse(
nw_slice, nw_slice,
lambda entry: isinstance(entry, lambda entry: isinstance(entry,
tensor.Variable)) tensor.Variable))
...@@ -1014,8 +1014,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1014,8 +1014,8 @@ class ScanSaveMem(gof.Optimizer):
nw_slice = (sanitize(position),) + \ nw_slice = (sanitize(position),) + \
tuple(old_slices[1:]) tuple(old_slices[1:])
subtens = tensor.basic.Subtensor(nw_slice) subtens = tensor.Subtensor(nw_slice)
sl_ins = tensor.basic.Subtensor.collapse( sl_ins = tensor.Subtensor.collapse(
nw_slice, nw_slice,
lambda entry: isinstance(entry, lambda entry: isinstance(entry,
tensor.Variable)) tensor.Variable))
......
...@@ -2034,7 +2034,7 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -2034,7 +2034,7 @@ class test_local_subtensor_merge(unittest.TestCase):
val = fun(data) val = fun(data)
assert numpy.all(val == data[3:6, 2:6, 1:7][1]) assert numpy.all(val == data[3:6, 2:6, 1:7][1])
assert len([n for n in fun.maker.fgraph.toposort() assert len([n for n in fun.maker.fgraph.toposort()
if isinstance(n.op, tensor.basic.Subtensor)]) == nops if isinstance(n.op, Subtensor)]) == nops
# test 2) # test 2)
y = x[2, 3][1] y = x[2, 3][1]
...@@ -2042,7 +2042,7 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -2042,7 +2042,7 @@ class test_local_subtensor_merge(unittest.TestCase):
val = fun(data) val = fun(data)
assert numpy.all(val == data[2, 3][1]) assert numpy.all(val == data[2, 3][1])
assert len([n for n in fun.maker.fgraph.toposort() assert len([n for n in fun.maker.fgraph.toposort()
if isinstance(n.op, tensor.basic.Subtensor)]) == nops if isinstance(n.op, Subtensor)]) == nops
# test 3) # test 3)
y = x[3:6, 2, 1:7][1] y = x[3:6, 2, 1:7][1]
...@@ -2050,7 +2050,7 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -2050,7 +2050,7 @@ class test_local_subtensor_merge(unittest.TestCase):
val = fun(data) val = fun(data)
assert numpy.all(val == data[3:6, 2, 1:7][1]) assert numpy.all(val == data[3:6, 2, 1:7][1])
assert len([n for n in fun.maker.fgraph.toposort() assert len([n for n in fun.maker.fgraph.toposort()
if isinstance(n.op, tensor.basic.Subtensor)]) == nops if isinstance(n.op, Subtensor)]) == nops
def test_scalar6(self): def test_scalar6(self):
# General case with one slice and one index # General case with one slice and one index
......
import logging
_logger = logging.getLogger("theano.tensor.type")
import numpy import numpy
import theano import theano
......
Markdown formatting is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment