提交 bb5a3b56 authored 作者: Colin Raffel's avatar Colin Raffel

Explicltly casing instead of type checking

上级 be56eff4
...@@ -3231,7 +3231,7 @@ gpu_join = GpuJoin() ...@@ -3231,7 +3231,7 @@ gpu_join = GpuJoin()
class GpuSplit(tensor.Split, GpuOp): class GpuSplit(tensor.Split, GpuOp):
def make_node(self, x, axis, splits): def make_node(self, x, axis, splits):
assert isinstance(x.type, CudaNdarrayType) x = as_cuda_ndarray_variable(x)
node = tensor.Split.make_node(self, x, axis, splits) node = tensor.Split.make_node(self, x, axis, splits)
outs = [CudaNdarrayType(dtype=o.dtype, outs = [CudaNdarrayType(dtype=o.dtype,
broadcastable=o.type.broadcastable)() broadcastable=o.type.broadcastable)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论