提交 50a61b52 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge, although mq should have helped preventing this...

...@@ -280,6 +280,15 @@ def local_gpu_shape(node): ...@@ -280,6 +280,15 @@ def local_gpu_shape(node):
return [gpu_shape(gpu_x)] return [gpu_shape(gpu_x)]
return False return False
@register_opt()
@local_optimizer([])
def local_gpu_rebroadcast(node):
    """Lift a Rebroadcast above a GPU->host transfer:
    rebroadcast(host_from_gpu(x)) -> host_from_gpu(rebroadcast(x))

    Returns a one-element replacement list on success, None otherwise.
    """
    if not isinstance(node.op, tensor.Rebroadcast):
        return
    host_input, = node.inputs
    producer = host_input.owner
    # Only rewrite when the input comes straight off the GPU.
    if producer is None or producer.op != host_from_gpu:
        return
    # Apply the same Rebroadcast op to the GPU variable, then transfer.
    rebroadcast_on_gpu = node.op(producer.inputs[0])
    return [host_from_gpu(rebroadcast_on_gpu)]
def cast(x, dtype): def cast(x, dtype):
stype = theano.scalar.Scalar(dtype) stype = theano.scalar.Scalar(dtype)
......
...@@ -2574,6 +2574,11 @@ class Rebroadcast(Op): ...@@ -2574,6 +2574,11 @@ class Rebroadcast(Op):
view_map = {0: [0]} view_map = {0: [0]}
def __init__(self, *axis): def __init__(self, *axis):
self.axis = dict(axis) self.axis = dict(axis)
def __str__(self):
broadcast_pattern = ['?' for i in range(1+numpy.max(self.axis.keys()))]
for k,v in self.axis.iteritems():
broadcast_pattern[k] = str(int(v))
return '%s{%s}' % (self.__class__.__name__, ','.join(broadcast_pattern))
def make_node(self, x): def make_node(self, x):
t = x.type.__class__(dtype = x.type.dtype, t = x.type.__class__(dtype = x.type.dtype,
broadcastable = [self.axis.get(i, b) broadcastable = [self.axis.get(i, b)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论