提交 3aa8499e 作者: James Bergstra

Used the casting operations in scalar.basic to downcast the gradients of inputs that had been upcast. I did this for add, sub, and div. Mul looked tricky, so I didn't do it.
上级 ed72a991
...@@ -617,7 +617,7 @@ class Add(ScalarOp): ...@@ -617,7 +617,7 @@ class Add(ScalarOp):
retval = [] retval = []
for i in inputs: for i in inputs:
if i.type in grad_types: if i.type in grad_types:
retval += [gz] retval += [cast(gz, i.type.dtype)]
else: else:
retval += [None] retval += [None]
return retval return retval
...@@ -656,15 +656,14 @@ class Sub(BinaryScalarOp): ...@@ -656,15 +656,14 @@ class Sub(BinaryScalarOp):
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if x.type in grad_types: if x.type in grad_types:
first_part = gz first_part = cast(gz, x.type.dtype)
else: else:
first_part = None first_part = None
if y.type in grad_types: if y.type in grad_types:
second_part = -gz second_part = cast(-gz, y.type.dtype)
else: else:
second_part = None second_part = None
return first_part, second_part return first_part, second_part
sub = Sub(upcast_out, name = 'sub') sub = Sub(upcast_out, name = 'sub')
...@@ -695,12 +694,12 @@ class TrueDiv(BinaryScalarOp): ...@@ -695,12 +694,12 @@ class TrueDiv(BinaryScalarOp):
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if x.type in grad_types: if x.type in grad_types:
first_part = gz / y first_part = cast(gz / y, x.type.dtype)
else: else:
first_part = None first_part = None
if y.type in grad_types: if y.type in grad_types:
second_part = -(gz * x) / (y * y) second_part = cast(-(gz * x) / (y * y), y.type.dtype)
else: else:
second_part = None second_part = None
return first_part, second_part return first_part, second_part
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论