提交 29adf634 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

Zeros -> filler

上级 a75410d8
"""A L{Result} to store L{numpy.ndarray} with basic accompanying L{Op}s"""
import sys # for sys.maxint
import inspect
import functools
import numpy
......@@ -605,12 +606,15 @@ tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh')
fill, fill_inplace = _elemwise(scal.second, 'fill')
def ones_like(model):
return fill(model, 1.0)
return Ones(model.type.ndim)(shape(model))
#return fill(model, 1.0)
def zeros_like(model):
return fill(model, 0.0)
return Zeros(model.type.ndim)(shape(model))
#return fill(model, 0.0)
class Zeros(gof.Op):
def __init__(self, ndim, dtype = 'float64'):
class Filler(gof.Op):
def __init__(self, value, ndim, dtype = 'float64'):
self.value = value
self.ndim = ndim
self.dtype = dtype
self.type = Tensor(dtype = dtype,
......@@ -622,9 +626,14 @@ class Zeros(gof.Op):
def perform(self, node, (dims,), (out,)):
if out[0] is not None:
out[0].resize(dims, refcheck = 0)
out[0].fill(0)
out[0].fill(self.value)
else:
out[0] = numpy.zeros(dims, dtype = self.dtype)
if self.value == 0:
out[0] = numpy.zeros(dims, dtype = self.dtype)
elif self.value == 1:
out[0] = numpy.ones(dims, dtype = self.dtype)
else:
out[0] = numpy.ones(dims, dtype = self.dtype) * self.value
def grad(self, (dims,), (gout,)):
return None,
......@@ -635,6 +644,9 @@ class Zeros(gof.Op):
def __hash__(self):
return hash(self.ndim) ^ hash(self.dtype)
Zeros = functools.partial(Filler, 0)
Ones = functools.partial(Filler, 1)
tensor_copy = elemwise.Elemwise(scal.identity)
......@@ -655,6 +667,26 @@ def mean(input, axis = None):
return s
class Repeat(gof.Op):
def make_node(self, input, repeats, axis):
assert isinstance(input.type, Tensor)
assert repeats.type == iscalar
assert axis.type == iscalar
type = Tensor(dtype = input.type.dtype,
broadcastable = [False if i==axis else x for i, x in enumerate(input.broadcastable)])
return gof.Apply(self, [inputs, repeats, axis], [type()])
def perform(self, node, (input, repeats, axis), (out, )):
out[0] = numpy.repeat(input, repeats, axis)
def grad(self, (input, repeats, axis), (gout, )):
return add.grad((input, gout), (gout,))[:1]
repeat = Repeat()
##########################
# Arithmetics
##########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论