提交 4e1a8f78 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add validate_shape to tensor.Alloc

上级 149263a9
......@@ -2422,14 +2422,9 @@ class Alloc(gof.Op):
"""
__props__ = ()
def make_node(self, value, *shape):
v = as_tensor_variable(value)
def validate_shape(self, shape):
sh = [as_tensor_variable(s) for s in shape]
bcast = []
if v.ndim > len(sh):
raise TypeError("The Alloc value to use has more dimensions"
" than the specified dimensions",
v.ndim, len(sh))
for i, s in enumerate(sh):
if s.type.dtype[:3] not in ('int', 'uin'):
if config.exception_verbosity == 'high':
......@@ -2445,8 +2440,17 @@ class Alloc(gof.Op):
except NotScalarConstantError:
const_shp = None
bcast.append(numpy.all(1 == const_shp))
return sh, bcast
def make_node(self, value, *shape):
v = as_tensor_variable(value)
sh, bcast = self.validate_shape(shape)
if v.ndim > len(sh):
raise TypeError("The Alloc value to use has more dimensions"
" than the specified dimensions",
v.ndim, len(sh))
otype = TensorType(dtype=v.dtype, broadcastable=bcast)
return gof.Apply(self, ([v] + sh), [otype()])
return gof.Apply(self, [v] + sh, [otype()])
def perform(self, node, inputs, out_):
out, = out_
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论