提交 2e8d00e9 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added zero() and one()

上级 1b4cc22d
...@@ -23,7 +23,6 @@ from elemwise import Elemwise, DimShuffle, CAReduce, Sum ...@@ -23,7 +23,6 @@ from elemwise import Elemwise, DimShuffle, CAReduce, Sum
import tensor_random as random import tensor_random as random
def as_tensor(x, name = None): def as_tensor(x, name = None):
if isinstance(x, gof.Apply): if isinstance(x, gof.Apply):
if len(x.outputs) != 1: if len(x.outputs) != 1:
...@@ -619,6 +618,7 @@ class Filler(gof.Op): ...@@ -619,6 +618,7 @@ class Filler(gof.Op):
broadcastable = (False,)*ndim) broadcastable = (False,)*ndim)
def make_node(self, dims): def make_node(self, dims):
dims = as_tensor(dims)
return gof.Apply(self, [dims], [self.type()]) return gof.Apply(self, [dims], [self.type()])
def perform(self, node, (dims,), (out,)): def perform(self, node, (dims,), (out,)):
...@@ -645,6 +645,11 @@ class Filler(gof.Op): ...@@ -645,6 +645,11 @@ class Filler(gof.Op):
Zeros = functools.partial(Filler, 0) Zeros = functools.partial(Filler, 0)
Ones = functools.partial(Filler, 1) Ones = functools.partial(Filler, 1)
def zero():
return Zeros(0)([])
def one():
return Ones(0)([])
tensor_copy = elemwise.Elemwise(scal.identity) tensor_copy = elemwise.Elemwise(scal.identity)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论