提交 b63f0dee authored 作者: notoraptor's avatar notoraptor

Check old pickled nodes for GpuMagmaSVD.

上级 5d82e9e6
......@@ -407,6 +407,15 @@ class GpuMagmaSVD(COp):
[GpuArrayType(A.dtype, broadcastable=[False],
context_name=ctx_name)()])
def prepare_node(self, node, storage_map, compute_map, impl):
# Check node to prevent eventual errors with old pickled nodes.
if self.compute_uv:
A, B, C = node.outputs
# We expect order: S (vector), U (matrix), VT (matrix)
assert A.type.ndim == 1 and B.type.ndim == C.type.ndim == 2, \
"Due to implementation constraints, GpuMagmaSVD interface has changed and now returns (S, U, VT) " \
"instead of (U, S, VT). Either update your code, or use gpu_svd() to get the expected (U, S, VT) order."
def get_params(self, node):
return self.params_type.get_params(self, context=node.inputs[0].type.context)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论