提交 799834f6 authored 作者: Frederic Bastien's avatar Frederic Bastien

reallow ndarray to be passed to Shape.make_node

上级 3a3424c2
...@@ -1444,11 +1444,13 @@ class Shape(Op): ...@@ -1444,11 +1444,13 @@ class Shape(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, x): def make_node(self, x):
if not isinstance(x, Variable):
raise TypeError('x must be Variable whose value have a shape attribute', x)
#Must work for all type that have a shape attribute. #Must work for all type that have a shape attribute.
#This will fail at execution time. #This will fail at execution time.
#x = as_tensor_variable(x) x = as_tensor_variable(x)
#Each type variable should implement their .shape attribute
#and have the fct infer_shape() implemented in the op that convert
#the type to TensorVariable to have the optimization working
#correctly.
return Apply(self, [x], [lvector()]) return Apply(self, [x], [lvector()])
def perform(self, node, (x, ), (out, )): def perform(self, node, (x, ), (out, )):
out[0] = theano._asarray(x.shape, dtype = 'int64') out[0] = theano._asarray(x.shape, dtype = 'int64')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论