提交 b4aefd7d authored 作者: Frederic Bastien's avatar Frederic Bastien

move sparse opt to the right file.

上级 9cbb58f4
......@@ -731,23 +731,6 @@ class CSMGrad(gof.op.Op):
csm_grad = CSMGrad
@gof.local_optimizer([csm_properties])
def skip_pack_csc01(node):
"""if we find csm_properties(CSM(*args)), then we can replace that with the
*args directly"""
if node.op == csm_properties:
csm, = node.inputs
if csm.owner and (csm.owner.op == CSC or csm.owner.op == CSR):
# csm.owner.inputs could be broadcastable. In that case, we have
# to adjust the broadcasting flag here.
ret_var = [tensor.patternbroadcast(i, o.broadcastable)
for i, o in izip(csm.owner.inputs, node.outputs)]
return ret_var
return False
register_specialize(skip_pack_csc01)
#
# Conversion
#
......
from itertools import izip
import theano
from theano import gof
from theano.sparse import Remove0
from theano.sparse import (CSC, CSR, csm_properties, Remove0,
register_specialize)
@gof.local_optimizer([None])
......@@ -17,3 +20,20 @@ theano.compile.optdb.register('local_inplace_remove0',
gof.TopoOptimizer(local_inplace_remove0,
failure_callback=gof.TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace')
@gof.local_optimizer([csm_properties])
def skip_pack_csc01(node):
"""if we find csm_properties(CSM(*args)), then we can replace that with the
*args directly"""
if node.op == csm_properties:
csm, = node.inputs
if csm.owner and (csm.owner.op == CSC or csm.owner.op == CSR):
# csm.owner.inputs could be broadcastable. In that case, we have
# to adjust the broadcasting flag here.
ret_var = [theano.tensor.patternbroadcast(i, o.broadcastable)
for i, o in izip(csm.owner.inputs, node.outputs)]
return ret_var
return False
register_specialize(skip_pack_csc01)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论