提交 0e01edd8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Reorganize and comment TensorFromScalar.grad

上级 b1b1cf6c
......@@ -1921,10 +1921,18 @@ class TensorFromScalar(Op):
def grad(self, inp, grads):
s, = inp
dt, = grads
assert dt.type.dtype.find('float') != -1
if s.type.dtype.find('int') != -1:
if s.type.dtype in float_dtypes:
assert dt.type.dtype in float_dtypes
return [scalar_from_tensor(dt)]
# If the input dtype is an integer, then so is the output dtype,
# and the "zero" gradient can be represented in that int dtype.
# Currently, theano.grad insists that the dtype of the returned
# gradient has a float dtype, so we use floatX.
if s.type.dtype in discrete_dtypes:
return [s.zeros_like().astype(theano.config.floatX)]
return [scalar_from_tensor(dt)]
raise NotImplementedError("grad not implemented for complex dtypes")
def __str__(self):
return self.__class__.__name__
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论