提交 10b7cb91 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make the op Rebroadcast work for the gpu. Miss the opt that will move the…

Make the op Rebroadcast work for the gpu. Miss the opt that will move the HostFromGpu after this op.
上级 3dd07e6b
...@@ -43,7 +43,9 @@ class CudaNdarrayType(Type): ...@@ -43,7 +43,9 @@ class CudaNdarrayType(Type):
A cyclic dependency is avoided by not hardcoding this class. A cyclic dependency is avoided by not hardcoding this class.
""" """
def __init__(self, broadcastable, name=None, dtype=None):
    """Initialize the type.

    :param broadcastable: iterable of bools, one per dimension; True
        marks the dimension as broadcastable.
    :param name: optional name for the type.
    :param dtype: optional dtype string; only 'float32' (or None,
        meaning the default float32) is accepted for now.
    :raises TypeError: if a dtype other than 'float32' is requested.
    """
    # BUG FIX: the original condition used `or`, which is true for
    # every possible dtype (including None and 'float32'), so the
    # constructor always raised. `and` expresses the intended check:
    # reject only an explicitly-given dtype that is not float32.
    if dtype is not None and dtype != 'float32':
        raise TypeError(self.__class__.__name__+' only support dtype float32 for now.')
    self.broadcastable = tuple(broadcastable)
    self.name = name
    self.dtype_specs() # error checking is done there
......
...@@ -2568,12 +2568,14 @@ class Rebroadcast(Op): ...@@ -2568,12 +2568,14 @@ class Rebroadcast(Op):
would make x broadcastable in axis 0 would make x broadcastable in axis 0
and not broadcastable in axis 1 and not broadcastable in axis 1
See also the unbroadcast function. See also the unbroadcast function.
.. note:: works inplace and works for CudaNdarrayType
""" """
view_map = {0: [0]} view_map = {0: [0]}
def __init__(self, *axis):
    """Store the requested (axis, broadcastable) pairs as a dict.

    Each positional argument is an (axis_index, bool) pair; dict()
    collapses them into a mapping from axis to its new
    broadcastable flag.
    """
    self.axis = dict(axis)
def make_node(self, x):
    """Build the Apply node for rebroadcasting *x*.

    The output type is built from ``x.type.__class__`` rather than a
    hard-coded TensorType so the op also works for CudaNdarrayType
    inputs. Axes listed in ``self.axis`` get their requested
    broadcastable flag; all other axes keep the input's flag.
    """
    new_broadcastable = [self.axis.get(dim, flag)
                         for dim, flag in enumerate(x.type.broadcastable)]
    out_type = x.type.__class__(dtype=x.type.dtype,
                                broadcastable=new_broadcastable)
    return Apply(self, [x], [out_type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论