提交 6ee3d69f authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added checks to split

上级 aff1bd89
...@@ -1807,6 +1807,12 @@ class Split(Op): ...@@ -1807,6 +1807,12 @@ class Split(Op):
if axis.type not in int_types: if axis.type not in int_types:
raise TypeError('axis must have type lscalar', axis.type) raise TypeError('axis must have type lscalar', axis.type)
# # The following lines are necessary if we allow splits of zero
# if isinstance(axis, gof.Constant):
# x = unbroadcast(x, int(axis.data))
# else:
# x = unbroadcast(x, *range(x.type.ndim))
inputs = [x, axis, splits] inputs = [x, axis, splits]
outputs = [x.type() for i in xrange(self.len_splits)] outputs = [x.type() for i in xrange(self.len_splits)]
...@@ -1823,6 +1829,11 @@ class Split(Op): ...@@ -1823,6 +1829,11 @@ class Split(Op):
if len(splits) != self.len_splits: if len(splits) != self.len_splits:
raise ValueError('In Split.perform(), len(splits) != len_splits.', raise ValueError('In Split.perform(), len(splits) != len_splits.',
(len(splits), self.len_splits)) (len(splits), self.len_splits))
if numpy.sum(splits) != len_along_axis:
raise ValueError('The splits sum to %s, expected %s' % (numpy.sum(splits), len_along_axis))
if not all(splits):
raise ValueError('Cannot have a split of zero.')
# Checking is done, let's roll the splitting algorithm! # Checking is done, let's roll the splitting algorithm!
# Basically we step along the given axis of x, extracting subtensors of size splits[i] # Basically we step along the given axis of x, extracting subtensors of size splits[i]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论