提交 d844c86b authored 作者: James Bergstra's avatar James Bergstra

cleaned up casting and dtype handling in mean and dot

上级 70b773a5
...@@ -1417,15 +1417,19 @@ def mean(input, axis = None): ...@@ -1417,15 +1417,19 @@ def mean(input, axis = None):
if str(input.dtype).startswith('int'): if str(input.dtype).startswith('int'):
# we need to cast eventually anyway, and this helps # we need to cast eventually anyway, and this helps
# to prevents overflow # to prevents overflow
input = convert_to_float64(input) input = cast(input, 'float64')
s = sum(input, axis) s = sum(input, axis)
shp = shape(input) shp = shape(input)
if input.dtype == 'float32':
shp = cast(shp, 'float32')
if axis is None: if axis is None:
axis = range(input.type.ndim) axis = range(input.type.ndim)
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
for i in axis: for i in axis:
s = s / shp[i] s = s / shp[i]
if input.dtype.startswith('float'):
assert input.dtype == s.dtype
return s return s
@constructor @constructor
...@@ -2543,12 +2547,15 @@ class Dot(Op): ...@@ -2543,12 +2547,15 @@ class Dot(Op):
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0: if gz.type.ndim == 0:
return gz * y, gz * x rval = gz * y, gz * x
if x.type.ndim == 1 and y.type.ndim > 1: elif x.type.ndim == 1 and y.type.ndim > 1:
return dot(gz, y.T), outer(x.T, gz) rval = dot(gz, y.T), outer(x.T, gz)
if x.type.ndim > 1 and y.type.ndim == 1: elif x.type.ndim > 1 and y.type.ndim == 1:
return outer(gz, y.T), dot(x.T, gz) rval = outer(gz, y.T), dot(x.T, gz)
return dot(gz, y.T), dot(x.T, gz) else:
rval = dot(gz, y.T), dot(x.T, gz)
return cast(rval[0], x.dtype), cast(rval[1], y.dtype)
def __str__(self): def __str__(self):
return "dot" return "dot"
dot = Dot() dot = Dot()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论