提交 4e1a8f78 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add validate_shape to tensor.Alloc

上级 149263a9
...@@ -2422,14 +2422,9 @@ class Alloc(gof.Op): ...@@ -2422,14 +2422,9 @@ class Alloc(gof.Op):
""" """
__props__ = () __props__ = ()
def make_node(self, value, *shape): def validate_shape(self, shape):
v = as_tensor_variable(value)
sh = [as_tensor_variable(s) for s in shape] sh = [as_tensor_variable(s) for s in shape]
bcast = [] bcast = []
if v.ndim > len(sh):
raise TypeError("The Alloc value to use has more dimensions"
" than the specified dimensions",
v.ndim, len(sh))
for i, s in enumerate(sh): for i, s in enumerate(sh):
if s.type.dtype[:3] not in ('int', 'uin'): if s.type.dtype[:3] not in ('int', 'uin'):
if config.exception_verbosity == 'high': if config.exception_verbosity == 'high':
...@@ -2445,8 +2440,17 @@ class Alloc(gof.Op): ...@@ -2445,8 +2440,17 @@ class Alloc(gof.Op):
except NotScalarConstantError: except NotScalarConstantError:
const_shp = None const_shp = None
bcast.append(numpy.all(1 == const_shp)) bcast.append(numpy.all(1 == const_shp))
return sh, bcast
def make_node(self, value, *shape):
v = as_tensor_variable(value)
sh, bcast = self.validate_shape(shape)
if v.ndim > len(sh):
raise TypeError("The Alloc value to use has more dimensions"
" than the specified dimensions",
v.ndim, len(sh))
otype = TensorType(dtype=v.dtype, broadcastable=bcast) otype = TensorType(dtype=v.dtype, broadcastable=bcast)
return gof.Apply(self, ([v] + sh), [otype()]) return gof.Apply(self, [v] + sh, [otype()])
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
out, = out_ out, = out_
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论