提交 a05ffe0f authored 作者: Frederic's avatar Frederic

Interface change: tensor.basic no longer contains the *Subtensor* objects.

Also move take to subtensor.py. This allows subtensor.py to depend on basic.py, which is a more sensible dependency.
上级 5ab56080
...@@ -4,6 +4,7 @@ __docformat__ = "restructuredtext en" ...@@ -4,6 +4,7 @@ __docformat__ = "restructuredtext en"
import warnings import warnings
from theano.tensor.basic import * from theano.tensor.basic import *
from theano.tensor.subtensor import *
from theano.tensor.type_other import * from theano.tensor.type_other import *
from theano.tensor import opt from theano.tensor import opt
......
...@@ -18,12 +18,6 @@ from theano.gof import Apply, Constant, Op, Variable ...@@ -18,12 +18,6 @@ from theano.gof import Apply, Constant, Op, Variable
from theano.tensor import elemwise from theano.tensor import elemwise
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
from theano.tensor.subtensor import (AdvancedIndexingError,
Subtensor, IncSubtensor,
inc_subtensor, set_subtensor,
AdvancedSubtensor1, AdvancedIncSubtensor1,
AdvancedSubtensor, AdvancedIncSubtensor,
advanced_subtensor1)
from theano import scalar as scal from theano import scalar as scal
from theano.gof.python25 import partial, any, all, maxsize from theano.gof.python25 import partial, any, all, maxsize
from theano.gof.utils import hashtype, MethodNotDefined from theano.gof.utils import hashtype, MethodNotDefined
...@@ -573,7 +567,7 @@ def get_scalar_constant_value(v): ...@@ -573,7 +567,7 @@ def get_scalar_constant_value(v):
ret = [[None]] ret = [[None]]
v.owner.op.perform(v.owner, [const], ret) v.owner.op.perform(v.owner, [const], ret)
return ret[0][0] return ret[0][0]
if isinstance(v.owner.op, Subtensor) and v.ndim == 0: if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
# This condition depends on Subtensor always embedding constant # This condition depends on Subtensor always embedding constant
# indices in the Op rather than making them inputs to the Apply # indices in the Op rather than making them inputs to the Apply
# node. # node.
...@@ -1199,8 +1193,8 @@ class _tensor_py_operators: ...@@ -1199,8 +1193,8 @@ class _tensor_py_operators:
axis = None axis = None
for i, arg in enumerate(args): for i, arg in enumerate(args):
try: try:
arg == numpy.newaxis or Subtensor.convert(arg) arg == numpy.newaxis or theano.tensor.subtensor.Subtensor.convert(arg)
except AdvancedIndexingError: except theano.tensor.subtensor.AdvancedIndexingError:
if advanced: if advanced:
axis = None axis = None
break break
...@@ -1220,7 +1214,7 @@ class _tensor_py_operators: ...@@ -1220,7 +1214,7 @@ class _tensor_py_operators:
theano.tensor.sharedvar.TensorSharedVariable))): theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis) return self.take(arg, axis)
else: else:
return AdvancedSubtensor()(self, *args) return theano.tensor.subtensor.AdvancedSubtensor()(self, *args)
else: else:
if numpy.newaxis in args: if numpy.newaxis in args:
# None (aka np.newaxis) in numpy indexing means to add a # None (aka np.newaxis) in numpy indexing means to add a
...@@ -1244,11 +1238,12 @@ class _tensor_py_operators: ...@@ -1244,11 +1238,12 @@ class _tensor_py_operators:
rval = view.__getitem__(tuple(new_args)) rval = view.__getitem__(tuple(new_args))
return rval return rval
else: else:
return Subtensor(args)(self, *Subtensor.collapse(args, return theano.tensor.subtensor.Subtensor(args)(
self, *theano.tensor.subtensor.Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable))) lambda entry: isinstance(entry, Variable)))
def take(self, indices, axis=None, mode='raise'): def take(self, indices, axis=None, mode='raise'):
return take(self, indices, axis, mode) return theano.tensor.subtensor.take(self, indices, axis, mode)
# COPYING # COPYING
def copy(self): def copy(self):
...@@ -3251,9 +3246,9 @@ class Alloc(gof.Op): ...@@ -3251,9 +3246,9 @@ class Alloc(gof.Op):
return False return False
elif (not isinstance(client[0], basestring) elif (not isinstance(client[0], basestring)
and isinstance(client[0].op, ( and isinstance(client[0].op, (
IncSubtensor, theano.tensor.subtensor.IncSubtensor,
AdvancedIncSubtensor1, theano.tensor.subtensor.AdvancedIncSubtensor1,
AdvancedIncSubtensor, theano.tensor.subtensor.AdvancedIncSubtensor,
))): ))):
return False return False
return True return True
...@@ -3828,7 +3823,7 @@ class Split(Op): ...@@ -3828,7 +3823,7 @@ class Split(Op):
out_shapes = [] out_shapes = []
for i in range(self.len_splits): for i in range(self.len_splits):
temp = as_tensor_variable(shp_x) temp = as_tensor_variable(shp_x)
temp = set_subtensor(temp[axis], splits[i]) temp = theano.tensor.subtensor.set_subtensor(temp[axis], splits[i])
temp = [temp[i] for i in range(len(shp_x))] temp = [temp[i] for i in range(len(shp_x))]
out_shapes.append(temp) out_shapes.append(temp)
return out_shapes return out_shapes
...@@ -5085,38 +5080,6 @@ def inverse_permutation(perm): ...@@ -5085,38 +5080,6 @@ def inverse_permutation(perm):
inverse=True) inverse=True)
def take(a, indices, axis=None, mode='raise'):
    """Build the graph selecting the entries of `a` at `indices` along `axis`.

    :param a: value converted with `as_tensor_variable`; the tensor to
        read from.
    :param indices: value converted with `as_tensor_variable`; the
        positions to select.  Indices that are not a vector are flattened,
        looked up, and the result is reshaped afterwards.
    :param axis: axis of `a` to index.  `None` means the lookup is done on
        `a.flatten()`.  Negative values count from the last axis.
    :param mode: 'clip' clips the indices to `[0, length - 1]`, 'wrap'
        takes them modulo `length`, where `length` is the size of the
        indexed axis (or of the flattened tensor when `axis` is None).
        Any other value leaves the indices untouched.

    :return: the symbolic result of `advanced_subtensor1`, reshaped when
        `indices` is not a vector.
    """
    a = as_tensor_variable(a)
    indices = as_tensor_variable(indices)

    if axis is not None and axis < 0:
        # Normalize once, up front: the slice `a.shape[axis + 1:]` used
        # below would otherwise select the whole shape for axis == -1.
        axis += a.ndim
        assert axis >= 0

    # Reuse advanced_subtensor1 if indices is a vector
    if indices.ndim == 1:
        if axis is None:
            # No axis: index the flattened tensor along its only axis, so
            # that 'clip' and 'wrap' use the flattened length instead of
            # evaluating `a.shape[None]`.
            source = a.flatten()
            lookup_axis = 0
        else:
            source = a
            lookup_axis = axis

        if mode == 'clip':
            indices = clip(indices, 0, source.shape[lookup_axis] - 1)
        elif mode == 'wrap':
            indices = indices % source.shape[lookup_axis]

        if lookup_axis == 0:
            return advanced_subtensor1(source, indices)

        # advanced_subtensor1 only indexes the first axis: swap the
        # requested axis to the front, index, then swap back (the
        # permutation is its own inverse).
        # list(...) keeps the item assignment valid when range() does not
        # return a list.
        shuffle = list(range(source.ndim))
        shuffle[0] = lookup_axis
        shuffle[lookup_axis] = 0
        return advanced_subtensor1(
            source.dimshuffle(shuffle), indices).dimshuffle(shuffle)

    # General case: flatten the indices, take with a vector, and restore
    # the shape contributed by `indices`.
    if axis is None:
        shape = indices.shape
        ndim = indices.ndim
    else:
        shape = concatenate(
            [a.shape[:axis], indices.shape, a.shape[axis + 1:]])
        ndim = a.ndim + indices.ndim - 1
    return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
......
...@@ -8,6 +8,7 @@ import numpy ...@@ -8,6 +8,7 @@ import numpy
import theano import theano
from theano import gof from theano import gof
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
from theano.tensor import subtensor
from theano.tensor import elemwise, dmatrix, fmatrix, dvector, fvector from theano.tensor import elemwise, dmatrix, fmatrix, dvector, fvector
from theano.tensor import opt from theano.tensor import opt
from theano.compile import optdb from theano.compile import optdb
...@@ -1004,7 +1005,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -1004,7 +1005,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
# typically we should not need the gradient w.r.t. dy). # typically we should not need the gradient w.r.t. dy).
y_idx_range = tensor.arange(y_idx.shape[0]) y_idx_range = tensor.arange(y_idx.shape[0])
g_dy = tensor.sum( g_dy = tensor.sum(
g_dx * tensor.AdvancedIncSubtensor()( g_dx * subtensor.AdvancedIncSubtensor()(
sm, tensor.fill(dy, -1), y_idx_range, y_idx), sm, tensor.fill(dy, -1), y_idx_range, y_idx),
axis=1) axis=1)
g_sm = dy.dimshuffle(0, 'x') * g_dx g_sm = dy.dimshuffle(0, 'x') * g_dx
...@@ -1396,7 +1397,7 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -1396,7 +1397,7 @@ def _check_rows_is_arange_len_labels(rows, labels):
# Not sure if that case happens any more after the introduction of # Not sure if that case happens any more after the introduction of
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present # ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, tensor.Subtensor): if isinstance(stop.owner.op, subtensor.Subtensor):
shape_subtensor = stop.owner shape_subtensor = stop.owner
if list(shape_subtensor.op.idx_list) == [0]: if list(shape_subtensor.op.idx_list) == [0]:
shape_var, = shape_subtensor.inputs shape_var, = shape_subtensor.inputs
...@@ -1424,7 +1425,7 @@ def local_advanced_indexing_crossentropy_onehot(node): ...@@ -1424,7 +1425,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
log = None log = None
sm = None sm = None
# First case: log(softmax(x))[rows, labels] # First case: log(softmax(x))[rows, labels]
if isinstance(node.op, tensor.AdvancedSubtensor): if isinstance(node.op, subtensor.AdvancedSubtensor):
try: try:
log, rows, labels = node.inputs log, rows, labels = node.inputs
except Exception: except Exception:
...@@ -1435,7 +1436,7 @@ def local_advanced_indexing_crossentropy_onehot(node): ...@@ -1435,7 +1436,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
# Second case: log(softmax(x)[rows, labels]) # Second case: log(softmax(x)[rows, labels])
if node.op == tensor.log: if node.op == tensor.log:
pre_log = node.inputs[0].owner pre_log = node.inputs[0].owner
if pre_log and isinstance(pre_log.op, tensor.AdvancedSubtensor): if pre_log and isinstance(pre_log.op, subtensor.AdvancedSubtensor):
try: try:
sm, rows, labels = pre_log.inputs sm, rows, labels = pre_log.inputs
except Exception: except Exception:
...@@ -1524,7 +1525,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1524,7 +1525,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# After the check for AdvancedIncSubtensor, if anything does not fit with # After the check for AdvancedIncSubtensor, if anything does not fit with
# the formula above, there's no way to fit it with the second case, # the formula above, there's no way to fit it with the second case,
# so we return immediately. # so we return immediately.
if d_sm.owner and isinstance(d_sm.owner.op, tensor.AdvancedIncSubtensor): if d_sm.owner and isinstance(d_sm.owner.op, subtensor.AdvancedIncSubtensor):
try: try:
z, incr, rows, labels = d_sm.owner.inputs z, incr, rows, labels = d_sm.owner.inputs
except Exception: except Exception:
...@@ -1566,7 +1567,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1566,7 +1567,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if not denom.owner: if not denom.owner:
return return
if isinstance(denom.owner.op, tensor.AdvancedSubtensor): if isinstance(denom.owner.op, subtensor.AdvancedSubtensor):
# Base case # Base case
adv_subtensor = denom adv_subtensor = denom
#out_grad /= 1. #out_grad /= 1.
...@@ -1575,7 +1576,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1575,7 +1576,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# and the output gradient # and the output gradient
for i, input in enumerate(denom.owner.inputs): for i, input in enumerate(denom.owner.inputs):
if input.owner and isinstance(input.owner.op, if input.owner and isinstance(input.owner.op,
tensor.AdvancedSubtensor): subtensor.AdvancedSubtensor):
other_inputs = [in_ for (j, other_inputs = [in_ for (j,
in_) in enumerate(denom.owner.inputs) if j != i] in_) in enumerate(denom.owner.inputs) if j != i]
if len(other_inputs) == 1: if len(other_inputs) == 1:
...@@ -1630,7 +1631,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1630,7 +1631,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
return return
# Check the numerator (AdvancedIncSubtensor) # Check the numerator (AdvancedIncSubtensor)
if num.owner and isinstance(num.owner.op, tensor.AdvancedIncSubtensor): if num.owner and isinstance(num.owner.op, subtensor.AdvancedIncSubtensor):
try: try:
z, incr, rows, labels = num.owner.inputs z, incr, rows, labels = num.owner.inputs
except Exception: except Exception:
......
...@@ -24,7 +24,8 @@ from theano.gof.python25 import maxsize ...@@ -24,7 +24,8 @@ from theano.gof.python25 import maxsize
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.configparser import config from theano.configparser import config
from theano.tensor.elemwise import Elemwise, DimShuffle from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import get_idx_list, get_canonical_form_slice from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, AdvancedIncSubtensor1)
from theano import scalar from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
...@@ -1218,13 +1219,13 @@ def local_track_shape_i(node): ...@@ -1218,13 +1219,13 @@ def local_track_shape_i(node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Subtensor]) @gof.local_optimizer([Subtensor])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
# replace all subtensor(make_vector) like: # replace all subtensor(make_vector) like:
# [a,b,c][0] -> a # [a,b,c][0] -> a
# [a,b,c][0:2] -> [a,b] # [a,b,c][0:2] -> [a,b]
# we can do this for constant indexes # we can do this for constant indexes
if isinstance(node.op, T.Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
x = node.inputs[0] x = node.inputs[0]
if x.owner and x.owner.op == make_vector: if x.owner and x.owner.op == make_vector:
...@@ -1592,12 +1593,12 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1592,12 +1593,12 @@ def local_upcast_elemwise_constant_inputs(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([T.Subtensor]) @gof.local_optimizer([Subtensor])
def local_useless_subtensor(node): def local_useless_subtensor(node):
""" """
Remove Subtensor if it takes the full input Remove Subtensor if it takes the full input
""" """
if isinstance(node.op, T.Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'): if not hasattr(node.fgraph, 'shape_feature'):
return return
...@@ -1678,7 +1679,7 @@ def local_subtensor_lift(node): ...@@ -1678,7 +1679,7 @@ def local_subtensor_lift(node):
when x,... are broadcasted scalar or not broadcasted at all when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => rebroadcast(x[idx]) rebroadcast(x)[idx] => rebroadcast(x[idx])
""" """
if isinstance(node.op, T.Subtensor): if isinstance(node.op, Subtensor):
u = node.inputs[0] u = node.inputs[0]
if not u.owner or len(u.clients) > 1: if not u.owner or len(u.clients) > 1:
return False return False
...@@ -1737,7 +1738,7 @@ def local_subtensor_lift(node): ...@@ -1737,7 +1738,7 @@ def local_subtensor_lift(node):
new_axis += [(j, u.broadcastable[i])] new_axis += [(j, u.broadcastable[i])]
j += 1 j += 1
subt_x = T.Subtensor(node.op.idx_list)(u.owner.inputs[0]) subt_x = Subtensor(node.op.idx_list)(u.owner.inputs[0])
rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x) rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
return [rbcast_subt_x] return [rbcast_subt_x]
...@@ -1886,9 +1887,9 @@ def local_subtensor_merge(node): ...@@ -1886,9 +1887,9 @@ def local_subtensor_merge(node):
expresses all slices in a canonical form, and then merges them together. expresses all slices in a canonical form, and then merges them together.
""" """
if isinstance(node.op, T.Subtensor): if isinstance(node.op, Subtensor):
u = node.inputs[0] u = node.inputs[0]
if u.owner and isinstance(u.owner.op, T.Subtensor): if u.owner and isinstance(u.owner.op, Subtensor):
# We can merge :) # We can merge :)
# x actual tensor on which we are picking slices # x actual tensor on which we are picking slices
x = u.owner.inputs[0] x = u.owner.inputs[0]
...@@ -1928,8 +1929,8 @@ def local_subtensor_merge(node): ...@@ -1928,8 +1929,8 @@ def local_subtensor_merge(node):
else: else:
merged_slices += slices1[pos_1:] merged_slices += slices1[pos_1:]
subtens = T.Subtensor(merged_slices) subtens = Subtensor(merged_slices)
sl_ins = T.Subtensor.collapse( sl_ins = Subtensor.collapse(
merged_slices, merged_slices,
lambda x: isinstance(x, T.Variable)) lambda x: isinstance(x, T.Variable))
out = subtens.make_node(x, *sl_ins).outputs[0] out = subtens.make_node(x, *sl_ins).outputs[0]
...@@ -1942,7 +1943,7 @@ def local_subtensor_merge(node): ...@@ -1942,7 +1943,7 @@ def local_subtensor_merge(node):
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_subtensor_of_alloc(node): def local_subtensor_of_alloc(node):
"""alloc[x:y] -> alloc""" """alloc[x:y] -> alloc"""
if not isinstance(node.op, T.Subtensor): if not isinstance(node.op, Subtensor):
return False return False
u = node.inputs[0] u = node.inputs[0]
if u.owner is None: if u.owner is None:
...@@ -2027,7 +2028,7 @@ def local_IncSubtensor_serialize(node): ...@@ -2027,7 +2028,7 @@ def local_IncSubtensor_serialize(node):
def movable(i): def movable(i):
# Return True iff this is a incsubtensor that we can move # Return True iff this is a incsubtensor that we can move
return i.owner \ return i.owner \
and isinstance(i.owner.op, T.IncSubtensor) \ and isinstance(i.owner.op, IncSubtensor) \
and i.type == o_type \ and i.type == o_type \
and len(i.clients) == 1 \ and len(i.clients) == 1 \
and not i.owner.op.set_instead_of_inc and not i.owner.op.set_instead_of_inc
...@@ -2061,7 +2062,7 @@ def local_inplace_setsubtensor(node): ...@@ -2061,7 +2062,7 @@ def local_inplace_setsubtensor(node):
""" """
Also work for GpuIncSubtensor Also work for GpuIncSubtensor
""" """
if isinstance(node.op, T.IncSubtensor) and not node.op.inplace: if isinstance(node.op, IncSubtensor) and not node.op.inplace:
new_op = node.op.__class__( new_op = node.op.__class__(
node.op.idx_list, inplace=True, node.op.idx_list, inplace=True,
set_instead_of_inc=node.op.set_instead_of_inc, set_instead_of_inc=node.op.set_instead_of_inc,
...@@ -2078,7 +2079,7 @@ compile.optdb.register('inplace_setsubtensor', ...@@ -2078,7 +2079,7 @@ compile.optdb.register('inplace_setsubtensor',
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_inplace_incsubtensor1(node): def local_inplace_incsubtensor1(node):
""" also work for GpuAdvancedIncSubtensor1 """ """ also work for GpuAdvancedIncSubtensor1 """
if isinstance(node.op, T.AdvancedIncSubtensor1) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
new_op = node.op.__class__( new_op = node.op.__class__(
inplace=True, set_instead_of_inc=node.op.set_instead_of_inc) inplace=True, set_instead_of_inc=node.op.set_instead_of_inc)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
...@@ -2098,7 +2099,7 @@ def local_incsubtensor_of_allocs(node): ...@@ -2098,7 +2099,7 @@ def local_incsubtensor_of_allocs(node):
""" """
IncSubtensor(x, zeros, idx) -> x IncSubtensor(x, zeros, idx) -> x
""" """
if isinstance(node.op, T.IncSubtensor) and not node.op.set_instead_of_inc: if isinstance(node.op, IncSubtensor) and not node.op.set_instead_of_inc:
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
...@@ -2123,7 +2124,7 @@ def local_setsubtensor_of_allocs(node): ...@@ -2123,7 +2124,7 @@ def local_setsubtensor_of_allocs(node):
when x is constant or alloc. when x is constant or alloc.
""" """
if isinstance(node.op, T.IncSubtensor) and node.op.set_instead_of_inc: if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace_x = None replace_x = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论