提交 5520a058 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

correction of grad method for DimShuffle

上级 40012bce
......@@ -390,8 +390,11 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
# Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph.
return [DimShuffle(gz.type.broadcastable, grad_order)(
Elemwise(scalar.identity)(gz))]
if 'int' in inp[0].dtype:
return [DisconnectedType()()]
else:
return [DimShuffle(gz.type.broadcastable, grad_order)(
Elemwise(scalar.identity)(gz))]
class DimShufflePrinter:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论