提交 d844c86b authored 作者: James Bergstra's avatar James Bergstra

cleaned up casting and dtype handling in mean and dot

上级 70b773a5
......@@ -1417,15 +1417,19 @@ def mean(input, axis = None):
if str(input.dtype).startswith('int'):
# we need to cast eventually anyway, and this helps
# to prevents overflow
input = convert_to_float64(input)
input = cast(input, 'float64')
s = sum(input, axis)
shp = shape(input)
if input.dtype == 'float32':
shp = cast(shp, 'float32')
if axis is None:
axis = range(input.type.ndim)
elif isinstance(axis, int):
axis = [axis]
for i in axis:
s = s / shp[i]
if input.dtype.startswith('float'):
assert input.dtype == s.dtype
return s
@constructor
......@@ -2543,12 +2547,15 @@ class Dot(Op):
def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0:
return gz * y, gz * x
if x.type.ndim == 1 and y.type.ndim > 1:
return dot(gz, y.T), outer(x.T, gz)
if x.type.ndim > 1 and y.type.ndim == 1:
return outer(gz, y.T), dot(x.T, gz)
return dot(gz, y.T), dot(x.T, gz)
rval = gz * y, gz * x
elif x.type.ndim == 1 and y.type.ndim > 1:
rval = dot(gz, y.T), outer(x.T, gz)
elif x.type.ndim > 1 and y.type.ndim == 1:
rval = outer(gz, y.T), dot(x.T, gz)
else:
rval = dot(gz, y.T), dot(x.T, gz)
return cast(rval[0], x.dtype), cast(rval[1], y.dtype)
def __str__(self):
return "dot"
dot = Dot()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论