提交 3525d5d0 authored 作者: James Bergstra's avatar James Bergstra

fixed bug in split grad, added var() to tensor.basic

上级 405feb22
......@@ -4,6 +4,7 @@ __docformat__ = "restructuredtext en"
import __builtin__
import sys # for sys.maxint
import traceback #for overriding Op.__call__
import functools
import numpy
......@@ -1196,7 +1197,12 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def mean(input, axis = None):
"""WRITEME"""
"""Compute the mean value along the given axis of a tensor `input`
:param axis: compute the mean along this axis of the tensor. None means trailing axis.
:type axis: None or int or (list of int) (see `Sum`)
"""
if str(input.dtype).startswith('int'):
# we need to cast eventually anyway, and this helps
# to prevents overflow
......@@ -1211,6 +1217,43 @@ def mean(input, axis = None):
s = s / shp[i]
return s
@constructor
def var(input, axis=None):
    """Compute the variance along the given axis of a tensor `input`.

    :param axis: compute the variance along this axis of the tensor.
        None means to reduce over *all* axes.
    :type axis: None or int or (list of int) (see `Sum`)

    :note: this is the biased (maximum-likelihood) estimator -- the mean
        of the squared deviations from the mean -- i.e. the same as
        numpy.var with its default ddof=0.
    """
    input_ndim = input.type.ndim
    # normalize `axis` to a list of ints
    if axis is None:
        axis = range(input_ndim)
    elif isinstance(axis, int):
        axis = [axis]
    # build a DimShuffle pattern that re-inserts, as broadcastable ('x')
    # dimensions, the axes that the mean-reduction removed, so the mean
    # can be broadcast back against `input`
    pattern = []
    next_dim = 0
    for i in range(input_ndim):
        if i in axis:
            pattern.append('x')
        else:
            pattern.append(next_dim)
            next_dim += 1
    # compute the axis-wise mean
    mean_input_reduced = mean(input, axis)
    # broadcast it back out to match `input`
    mean_input = DimShuffle(
            list(mean_input_reduced.type.broadcastable),
            pattern)(mean_input_reduced)
    # variance = mean of the squared deviations from the mean
    centered_input = input - mean_input
    return mean(centered_input**2, axis)
class Repeat(gof.Op):
......@@ -1571,6 +1614,14 @@ class Split(Op):
def __hash__(self):
    """Hash on the Split Op class itself combined (xor) with the number
    of requested splits, so instances configured with the same
    `len_splits` hash alike.
    """
    # NOTE(review): assumes an __eq__ elsewhere compares by len_splits
    # as well -- not visible in this chunk, confirm they stay in sync.
    return hash(Split) ^ self.len_splits
def __call__(self, *inputs, **kwargs):
    """Build an Apply node and return its *list* of outputs.

    Overrides Op.__call__ to suppress the usual unpacking of a
    single-output list -- Split always yields multiple outputs.
    """
    applied = self.make_node(*inputs, **kwargs)
    # remember where this node was created, for debugging
    # (drop the last frame: this __call__ itself)
    applied.tag.trace = traceback.extract_stack()[:-1]
    return applied.outputs
def make_node(self, x, axis, splits):
"""WRITEME"""
x = as_tensor_variable(x)
......@@ -1696,9 +1747,10 @@ class Join(Op):
"""
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype:
# assume that this isn't differentiable
# assume that this is differentiable
split = Split(len(tensors))
return [None] + split(gz, axis, stack(*[shape(x)[axis] for x in tensors]))
split_gz = split(gz, axis, stack(*[shape(x)[axis] for x in tensors]))
return [None] + split_gz
else:
# assume that this isn't differentiable
return [None] * (1 + len(tensors))
......
......@@ -1859,6 +1859,23 @@ def test_reshape_member_fn():
y = x.reshape((4,5,6))
assert y.owner.op == Reshape(3)
def test_var():
a = Tensor(dtype='float64', broadcastable=[False,False,False])()
f = function([a], var(a))
a_val = numpy.arange(60).reshape(3,4,5)
print numpy.var(a_val)
print f(a_val)
assert numpy.allclose(numpy.var(a_val), f(a_val))
f = function([a], var(a, axis=0))
assert numpy.allclose(numpy.var(a_val, axis=0), f(a_val))
f = function([a], var(a, axis=1))
assert numpy.allclose(numpy.var(a_val, axis=1), f(a_val))
f = function([a], var(a, axis=2))
assert numpy.allclose(numpy.var(a_val, axis=2), f(a_val))
if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论