提交 3aa8499e authored 作者: James Bergstra's avatar James Bergstra

Used the casting operations in scalar.basic to downcast gradients of inputs that

had been upcast. I did this for add, sub, div. Mul looked tricky so I didn't do it.
上级 ed72a991
......@@ -617,7 +617,7 @@ class Add(ScalarOp):
retval = []
for i in inputs:
if i.type in grad_types:
retval += [gz]
retval += [cast(gz, i.type.dtype)]
else:
retval += [None]
return retval
......@@ -656,15 +656,14 @@ class Sub(BinaryScalarOp):
return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
if x.type in grad_types:
first_part = gz
first_part = cast(gz, x.type.dtype)
else:
first_part = None
first_part = None
if y.type in grad_types:
second_part = -gz
second_part = cast(-gz, y.type.dtype)
else:
second_part = None
second_part = None
return first_part, second_part
sub = Sub(upcast_out, name = 'sub')
......@@ -695,12 +694,12 @@ class TrueDiv(BinaryScalarOp):
return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
if x.type in grad_types:
first_part = gz / y
first_part = cast(gz / y, x.type.dtype)
else:
first_part = None
if y.type in grad_types:
second_part = -(gz * x) / (y * y)
second_part = cast(-(gz * x) / (y * y), y.type.dtype)
else:
second_part = None
return first_part, second_part
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论