提交 2deb19cb authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Some more ShapeError fixes, and moved ShapeError to tensor.basic to avoid import issues

上级 35ff6f6a
...@@ -36,7 +36,7 @@ class Apply(gof.Apply): ...@@ -36,7 +36,7 @@ class Apply(gof.Apply):
try: try:
oshapes = infer_shape(self, ishapes) oshapes = infer_shape(self, ishapes)
except theano.tensor.opt.ShapeError: except theano.tensor.ShapeError:
return return
for o, oshp in zip(outputs, oshapes): for o, oshp in zip(outputs, oshapes):
......
...@@ -45,6 +45,11 @@ int_dtypes = map(str, scal.int_types) ...@@ -45,6 +45,11 @@ int_dtypes = map(str, scal.int_types)
discrete_dtypes = map(str, scal.discrete_types) discrete_dtypes = map(str, scal.discrete_types)
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
pass
def check_equal_numpy(x, y): def check_equal_numpy(x, y):
""" """
Returns True iff x and y are equal (checks the dtype and Returns True iff x and y are equal (checks the dtype and
...@@ -3576,14 +3581,14 @@ class Join(Op): ...@@ -3576,14 +3581,14 @@ class Join(Op):
# that whenever I get a None. Should we just remove gof/apply_shape # that whenever I get a None. Should we just remove gof/apply_shape
# if it is depricated ?? # if it is depricated ??
if ishapes[1] is None: if ishapes[1] is None:
raise NotImplementedError raise ShapeError()
n_dim = len(ishapes[1]) n_dim = len(ishapes[1])
for shape in ishapes[1:]: for shape in ishapes[1:]:
if shape is None: if shape is None:
raise NotImplementedError raise ShapeError()
for shape_i in shape: for shape_i in shape:
if shape_i is None: if shape_i is None:
raise NotImplementedError raise ShapeError()
# at this point the inputs have been broadcasted so they should # at this point the inputs have been broadcasted so they should
# all have the same shape # all have the same shape
assert len(shape) == n_dim assert len(shape) == n_dim
......
...@@ -575,14 +575,14 @@ class ConvOp(Op): ...@@ -575,14 +575,14 @@ class ConvOp(Op):
try: try:
fmshp = ConvOp.getOutputShape(imshp[1:], kshp, (self.dx,self.dy), self.out_mode) fmshp = ConvOp.getOutputShape(imshp[1:], kshp, (self.dx,self.dy), self.out_mode)
except TypeError: except TypeError:
raise theano.tensor.opt.ShapeError() raise theano.tensor.ShapeError()
outshp = (batch_size,fmo) + tuple(fmshp) outshp = (batch_size,fmo) + tuple(fmshp)
return [outshp] return [outshp]
else: else:
# Haven't implemented this case. imshp and kshp may be symbollic # Haven't implemented this case. imshp and kshp may be symbollic
# and ConvOp.getOutputShape doesn't handle this. In this case # and ConvOp.getOutputShape doesn't handle this. In this case
# we simply let the default function do its work. # we simply let the default function do its work.
raise theano.tensor.opt.ShapeError() raise theano.tensor.ShapeError()
def perform(self,node, inp, out): def perform(self,node, inp, out):
......
...@@ -27,17 +27,12 @@ from theano import compile #to register the optimizer built by this file ...@@ -27,17 +27,12 @@ from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer
from theano.gof import toolbox, DestroyHandler from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value from basic import get_constant_value, ShapeError
# Utilities # Utilities
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
pass
def out2in(*local_opts): def out2in(*local_opts):
"""WRITEME """ """WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
...@@ -727,7 +722,7 @@ class ShapeFeature(object): ...@@ -727,7 +722,7 @@ class ShapeFeature(object):
'Code called by infer_shape failed raising a ' 'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to ' 'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer ' 'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.opt.ShapeError ' 'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e) 'instead. The original exception message is: %s' % e)
except Exception, e: except Exception, e:
_logger.error('Failed to infer_shape from Op %s (i_shapes=%s): %s %s'% (node.op, _logger.error('Failed to infer_shape from Op %s (i_shapes=%s): %s %s'% (node.op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论