提交 0e01edd8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Reorganize and comment TensorFromScalar.grad

上级 b1b1cf6c
...@@ -1921,10 +1921,18 @@ class TensorFromScalar(Op): ...@@ -1921,10 +1921,18 @@ class TensorFromScalar(Op):
def grad(self, inp, grads):
    """Gradient of TensorFromScalar with respect to its scalar input.

    Parameters
    ----------
    inp : sequence of one variable
        The scalar input `s` of the Op.
    grads : sequence of one variable
        The gradient `dt` flowing in from the tensor output.

    Returns
    -------
    list of one variable
        The gradient with respect to `s`.

    Raises
    ------
    NotImplementedError
        If the input dtype is complex (no gradient is defined there).
    """
    s, = inp
    dt, = grads
    if s.type.dtype in float_dtypes:
        # A float input must receive a float output gradient.
        assert dt.type.dtype in float_dtypes
        return [scalar_from_tensor(dt)]

    # If the input dtype is an integer, then so is the output dtype,
    # and the "zero" gradient can be represented in that int dtype.
    # Currently, theano.grad insists that the dtype of the returned
    # gradient has a float dtype, so we use floatX.
    if s.type.dtype in discrete_dtypes:
        return [s.zeros_like().astype(theano.config.floatX)]

    raise NotImplementedError("grad not implemented for complex dtypes")
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论