提交 6f2690dd authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove apply_shape, and tag.shape of variables. Closes #633.

上级 360377fd
"""Apply for use with Tensors that implements shape propagation via variable.tag.shape
This is not used currently very used. It appear in some case, but I'm not sure it if work or if it is used by default.
It could help the current system to make it detect problem earlier when contructing the graph instead of during optimization.
"""
import sys
from theano import gof
def ishape(v):
try:
return (True, v.tag.shape)
except AttributeError:
return (False, (None,)*v.type.ndim)
class Apply(gof.Apply):
def __init__(self, op, inputs, outputs):
super(Apply, self).__init__(op, inputs, outputs)
if not inputs:
return
# if any input has any shape info, then propagate it
try:
provided, ishapes = zip(*[ishape(i) for i in inputs])
except AttributeError:
# i.type.ndim didn't make sense for some i
return
if provided == [False for i in inputs]:
# no input had a tag.shape
return
try:
infer_shape = op.infer_shape
except AttributeError:
# op has no infer_shape, that's fine
return
try:
oshapes = infer_shape(self, ishapes)
except NotImplementedError:
return
for o, oshp in zip(outputs, oshapes):
o.tag.shape = oshp
...@@ -12,8 +12,7 @@ import numpy, theano ...@@ -12,8 +12,7 @@ import numpy, theano
#from copy import copy as python_copy #from copy import copy as python_copy
from theano import gof, shared from theano import gof, shared
from theano.gof import Variable, Op, Type, Constant, Value from theano.gof import Apply, Constant, Op, Type, Value, Variable
from theano.gof.apply_shape import Apply
from theano import gradient from theano import gradient
...@@ -287,7 +286,6 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -287,7 +286,6 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
TensorType(dtype = x_.dtype, broadcastable = bcastable), TensorType(dtype = x_.dtype, broadcastable = bcastable),
x_.copy(), x_.copy(),
name=name) name=name)
rval.tag.shape = x_.shape
return rval return rval
else: else:
# leave the shape out of the type # leave the shape out of the type
...@@ -3501,25 +3499,12 @@ class Join(Op): ...@@ -3501,25 +3499,12 @@ class Join(Op):
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
# Join op should get at least two inputs to join # ishapes[0] contains the size of the axis on which we join
# Join op should get at least one input to join
assert len(ishapes) > 1 assert len(ishapes) > 1
# Not sure this is needed anymore :( ... basically the apply_shape
# version of the apply node (i.e. the one defined in
# gof/apply_shape) calls infer_shape methods passing None to unknown
# inputs. It can handle NotImplementedError, so for now I just raise
# that whenever I get a None. Should we just remove gof/apply_shape
# if it is depricated ??
if ishapes[1] is None:
raise NotImplementedError
n_dim = len(ishapes[1]) n_dim = len(ishapes[1])
for shape in ishapes[1:]: for shape in ishapes[1:]:
if shape is None: assert shape is not None
raise NotImplementedError
for shape_i in shape:
if shape_i is None:
raise NotImplementedError
# at this point the inputs have been broadcasted so they should
# all have the same shape
assert len(shape) == n_dim assert len(shape) == n_dim
out_shapes = [] out_shapes = []
...@@ -3837,9 +3822,6 @@ def reshape(x, newshape, ndim=None, name=None): ...@@ -3837,9 +3822,6 @@ def reshape(x, newshape, ndim=None, name=None):
ndim = get_vector_length(newshape) ndim = get_vector_length(newshape)
op = Reshape(ndim, name) op = Reshape(ndim, name)
rval = op(x, newshape) rval = op(x, newshape)
if isinstance(newshape, (list, tuple)):
rval.tag.shape = newshape
return rval return rval
class Flatten(Op): class Flatten(Op):
......
...@@ -6,16 +6,13 @@ import numpy.distutils ...@@ -6,16 +6,13 @@ import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, PatternSub, DestroyHandler, from theano.gof import (utils, Op, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer, SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer) InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer, Apply)
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
import theano.scalar import theano.scalar
import basic as T import basic as T
from theano.gof.apply_shape import Apply
#NB: this clobbers the builtin 'compile' symbol #NB: this clobbers the builtin 'compile' symbol
from theano import compile #to register the optimizer built by this file from theano import compile #to register the optimizer built by this file
......
...@@ -5,12 +5,11 @@ import numpy ...@@ -5,12 +5,11 @@ import numpy
import elemwise_cgen as cgen import elemwise_cgen as cgen
import theano import theano
from theano import gof from theano import gof
from theano.gof import Op from theano.gof import Apply, Op
from theano import scalar from theano import scalar
from theano.scalar import Scalar from theano.scalar import Scalar
from theano.printing import pprint from theano.printing import pprint
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano.gof.apply_shape import Apply
# tensor depends on elemwise to provide definitions for several ops # tensor depends on elemwise to provide definitions for several ops
......
...@@ -18,7 +18,7 @@ import theano ...@@ -18,7 +18,7 @@ import theano
from theano.tensor import (as_tensor_variable, blas, get_constant_value, from theano.tensor import (as_tensor_variable, blas, get_constant_value,
patternbroadcast) patternbroadcast)
from theano import Op, config from theano import Op, config
from theano.gof.apply_shape import Apply from theano.gof import Apply
from theano.gof.python25 import any from theano.gof.python25 import any
imported_scipy_signal = False imported_scipy_signal = False
......
...@@ -11,7 +11,7 @@ from theano.tensor import basic as tensor ...@@ -11,7 +11,7 @@ from theano.tensor import basic as tensor
from theano.tensor import elemwise, dmatrix, fmatrix, dvector, fvector from theano.tensor import elemwise, dmatrix, fmatrix, dvector, fvector
from theano.tensor import opt from theano.tensor import opt
from theano.compile import optdb from theano.compile import optdb
from theano.gof.apply_shape import Apply from theano.gof import Apply
from theano.tensor.nnet.sigm import sigmoid, softplus from theano.tensor.nnet.sigm import sigmoid, softplus
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论