提交 1fca3a15 authored 作者: James Bergstra's avatar James Bergstra

Added support for CudaNdarray Constants

- pickling and unpickling them when they are in graph keys. - adding a Signature type.
上级 bd004ea6
...@@ -241,3 +241,18 @@ class CudaNdarrayType(Type): ...@@ -241,3 +241,18 @@ class CudaNdarrayType(Type):
def c_compiler(self): def c_compiler(self):
return nvcc_module_compile_str return nvcc_module_compile_str
# THIS WORKS
# But CudaNdarray instances don't compare equal to one another, and what about __hash__ ?
# So the unpickled version doesn't equal the pickled version, and the cmodule cache is not
# happy with the situation.
import copy_reg
def CudaNdarray_unpickler(npa):
return cuda_ndarray.CudaNdarray(npa)
copy_reg.constructor(CudaNdarray_unpickler)
def CudaNdarray_pickler(cnda):
return (CudaNdarray_unpickler, (numpy.asarray(cnda),))
copy_reg.pickle(cuda_ndarray.CudaNdarray, CudaNdarray_pickler, CudaNdarray_unpickler)
...@@ -33,8 +33,12 @@ class CudaNdarrayVariable(Variable, _operators): ...@@ -33,8 +33,12 @@ class CudaNdarrayVariable(Variable, _operators):
pass pass
CudaNdarrayType.Variable = CudaNdarrayVariable CudaNdarrayType.Variable = CudaNdarrayVariable
class CudaNdarrayConstant(Constant, _operators): class CudaNdarrayConstantSignature(tensor.TensorConstantSignature):
pass pass
class CudaNdarrayConstant(Constant, _operators):
def signature(self):
return CudaNdarrayConstantSignature((self.type, numpy.asarray(self.data)))
CudaNdarrayType.Constant = CudaNdarrayConstant CudaNdarrayType.Constant = CudaNdarrayConstant
class CudaNdarraySharedVariable(SharedVariable, _operators): class CudaNdarraySharedVariable(SharedVariable, _operators):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论