提交 2deb19cb authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Some more ShapeError fixes, and moved ShapeError to tensor.basic to avoid import issues

上级 35ff6f6a
......@@ -36,7 +36,7 @@ class Apply(gof.Apply):
try:
oshapes = infer_shape(self, ishapes)
except theano.tensor.opt.ShapeError:
except theano.tensor.ShapeError:
return
for o, oshp in zip(outputs, oshapes):
......
......@@ -45,6 +45,11 @@ int_dtypes = map(str, scal.int_types)
discrete_dtypes = map(str, scal.discrete_types)
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
pass
def check_equal_numpy(x, y):
"""
Returns True iff x and y are equal (checks the dtype and
......@@ -3576,14 +3581,14 @@ class Join(Op):
# that whenever I get a None. Should we just remove gof/apply_shape
# if it is depricated ??
if ishapes[1] is None:
raise NotImplementedError
raise ShapeError()
n_dim = len(ishapes[1])
for shape in ishapes[1:]:
if shape is None:
raise NotImplementedError
raise ShapeError()
for shape_i in shape:
if shape_i is None:
raise NotImplementedError
raise ShapeError()
# at this point the inputs have been broadcasted so they should
# all have the same shape
assert len(shape) == n_dim
......
......@@ -575,14 +575,14 @@ class ConvOp(Op):
try:
fmshp = ConvOp.getOutputShape(imshp[1:], kshp, (self.dx,self.dy), self.out_mode)
except TypeError:
raise theano.tensor.opt.ShapeError()
raise theano.tensor.ShapeError()
outshp = (batch_size,fmo) + tuple(fmshp)
return [outshp]
else:
# Haven't implemented this case. imshp and kshp may be symbollic
# and ConvOp.getOutputShape doesn't handle this. In this case
# we simply let the default function do its work.
raise theano.tensor.opt.ShapeError()
raise theano.tensor.ShapeError()
def perform(self,node, inp, out):
......
......@@ -27,17 +27,12 @@ from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer
from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value
from basic import get_constant_value, ShapeError
# Utilities
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
pass
def out2in(*local_opts):
"""WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
......@@ -727,7 +722,7 @@ class ShapeFeature(object):
'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.opt.ShapeError '
'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e)
except Exception, e:
_logger.error('Failed to infer_shape from Op %s (i_shapes=%s): %s %s'% (node.op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论