提交 e14f7a33 authored 作者: Sina Honari's avatar Sina Honari

removing a bug

上级 91de659f
......@@ -834,22 +834,22 @@ class Cast(gof.op.Op):
def grad(self, inputs, outputs_gradients):
gz = outputs_gradients[0]
if gz.dtype in complex_types:
if gz.dtype in complex_dtypes:
raise NotImplementedError("grad not implemented for complex types")
if inputs[0].dtype in complex_types:
if inputs[0].dtype in complex_dtypes:
raise NotImplementedError("grad not implemented for complex types")
if gz.dtype in tensor.continuous_dtypes:
if inputs[0].dtype in tensor.continuous_dtypes:
return [Cast(inputs[0].dtype)(gz)]
if gz.dtype in discrete_dtypes:
if inputs[0].dtype in discrete_dtypes:
return [inputs[0].zeros_like(dtype=theano.config.floatX)]
else:
return [gz]
else:
if inputs[0].dtype in tensor.continuous_dtypes:
return [inputs[0].zeros_like()]
else:
if inputs[0].dtype in discrete_dtypes:
return [gz]
else:
return [inputs[0].zeros_like(dtype=theano.config.floatX)]
return [Cast(inputs[0].dtype)(gz)]
def infer_shape(self, node, ins_shapes):
return ins_shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论