提交 a712caa2 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Check for SparseType by class name, since theano.sparse is not imported.

上级 8dc58fbb
...@@ -114,11 +114,11 @@ class OutputGuard(gof.Op): ...@@ -114,11 +114,11 @@ class OutputGuard(gof.Op):
return """ return """
%(z)s = %(x)s; %(z)s = %(x)s;
""" % locals() """ % locals()
elif isinstance(node.inputs[0].type, elif (isinstance(node.inputs[0].type,
(theano.tensor.TensorType, (theano.tensor.TensorType,
theano.sandbox.cuda.CudaNdarrayType, theano.sandbox.cuda.CudaNdarrayType,
theano.sparse.SparseType, theano.tensor.raw_random.RandomStateType)) or
theano.tensor.raw_random.RandomStateType) node.inputs[0].type.__class__.__name__ == 'SparseType'
): ):
# These are Python object types # These are Python object types
return """ return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论