提交 02df7759 authored 作者: Frederic Bastien's avatar Frederic Bastien

more backport

上级 17d5c707
...@@ -110,7 +110,12 @@ class GpuElemwise(Op): ...@@ -110,7 +110,12 @@ class GpuElemwise(Op):
def _rehash(self): def _rehash(self):
items = self.inplace_pattern.items() items = self.inplace_pattern.items()
items.sort() items.sort()
tuple_items = tuple([k for k,v in items] + [(tuple(v) if isinstance(v, (tuple, list)) else v) for k,v in items]) tuple_items=[k for k,v in items]
for k,v in items:
if isinstance(v, (tuple, list)):
tuple_items+=[tuple(v)]
else: tuple_items+=[v]
tuple_items = tuple(tuple_items)
h = hash(type(self)) ^ hash(self.scalar_op) ^ hash(tuple_items) h = hash(type(self)) ^ hash(self.scalar_op) ^ hash(tuple_items)
# don't change a code that has already been computed for this object # don't change a code that has already been computed for this object
assert h == getattr(self,'_hashval', h) assert h == getattr(self,'_hashval', h)
......
import StringIO, sys import StringIO, sys
import numpy
from theano import Op, Type, Apply, Variable, Constant from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, scalar from theano import tensor, scalar
...@@ -16,8 +17,16 @@ def debug(*msg): ...@@ -16,8 +17,16 @@ def debug(*msg):
def _logical_scalar(x): def _logical_scalar(x):
return all(x.type.broadcastable) return numpy.all(x.type.broadcastable)
def get_str_list_logical_scalar(node, value_str='ii_i%i_value', data_str='ii_i%i_data[0]'):
l=[]
for ipos, i in enumerate(node.inputs):
if _logical_scalar(i):
l+=[value_str%ipos]
else: l+=[data_str%ipos]
return l
class RecAlgo(object): class RecAlgo(object):
def c_src_kernel(self, node, nodename): def c_src_kernel(self, node, nodename):
nd = node.outputs[0].type.ndim nd = node.outputs[0].type.ndim
...@@ -83,7 +92,7 @@ class RecAlgo(object): ...@@ -83,7 +92,7 @@ class RecAlgo(object):
[scalar.Scalar(dtype = input.type.dtype)() for input in node.inputs], [scalar.Scalar(dtype = input.type.dtype)() for input in node.inputs],
[scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs]) [scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs])
, nodename + '_scalar_' , nodename + '_scalar_'
, [('ii_i%i_value' if _logical_scalar(i) else 'ii_i%i_data[0]')%ipos for ipos, i in enumerate(node.inputs)] , get_str_list_logical_scalar(node)
, ['ii_o%i_data[0]'%ipos for ipos, i in enumerate(node.outputs)] , ['ii_o%i_data[0]'%ipos for ipos, i in enumerate(node.outputs)]
, sub=dict(fail='return;')) #TODO: set a failure code somehow!!! , sub=dict(fail='return;')) #TODO: set a failure code somehow!!!
print >> sio, " ", task_code print >> sio, " ", task_code
...@@ -271,7 +280,7 @@ class NaiveAlgo(object): ...@@ -271,7 +280,7 @@ class NaiveAlgo(object):
[scalar.Scalar(dtype = input.type.dtype)() for input in node.inputs], [scalar.Scalar(dtype = input.type.dtype)() for input in node.inputs],
[scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs]) [scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs])
, nodename + '_scalar_' , nodename + '_scalar_'
, [('ii_i%i_value' if _logical_scalar(i) else 'ii_i%i_data[0]')%ipos for ipos, i in enumerate(node.inputs)] , get_str_list_logical_scalar(node)
, ['ii_o%i_data[0]'%ipos for ipos, i in enumerate(node.outputs)] , ['ii_o%i_data[0]'%ipos for ipos, i in enumerate(node.outputs)]
, sub=dict(fail='return;')) #TODO: set a failure code somehow!!! , sub=dict(fail='return;')) #TODO: set a failure code somehow!!!
print >> sio, " ", task_code print >> sio, " ", task_code
...@@ -394,7 +403,7 @@ class NaiveAlgo(object): ...@@ -394,7 +403,7 @@ class NaiveAlgo(object):
[scalar.Scalar(dtype = input.type.dtype)() for input in node.inputs], [scalar.Scalar(dtype = input.type.dtype)() for input in node.inputs],
[scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs]) [scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs])
, nodename + '_scalar_' , nodename + '_scalar_'
, [('value0[%i]' if _logical_scalar(i) else 'ii_i%i_data[0]')%ipos for ipos, i in enumerate(node.inputs)] , get_str_list_logical_scalar(node, value_str='value0[%i]')
, ['ii_o%i_data[0]'%ipos for ipos, i in enumerate(node.outputs)] , ['ii_o%i_data[0]'%ipos for ipos, i in enumerate(node.outputs)]
, sub=dict(fail='return;')) #TODO: set a failure code somehow!!! , sub=dict(fail='return;')) #TODO: set a failure code somehow!!!
print >> sio, " ", task_code print >> sio, " ", task_code
...@@ -605,7 +614,7 @@ class NaiveAlgo(object): ...@@ -605,7 +614,7 @@ class NaiveAlgo(object):
[scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs]) [scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs])
, nodename + '_scalar_' , nodename + '_scalar_'
#, ['i%i_data[i]'%ipos for ipos, i in enumerate(node.inputs)] #, ['i%i_data[i]'%ipos for ipos, i in enumerate(node.inputs)]
, [('ii_i%i_value' if _logical_scalar(i) else 'i%i_data[i]')%ipos for ipos, i in enumerate(node.inputs)] , get_str_list_logical_scalar(node, data_str='i%i_data[i]')
, ['o%i_data[i]'%ipos for ipos, i in enumerate(node.outputs)] , ['o%i_data[i]'%ipos for ipos, i in enumerate(node.outputs)]
, sub=dict(fail='return;')) #TODO: set a failure code somehow!!! , sub=dict(fail='return;')) #TODO: set a failure code somehow!!!
print >> sio, " ", task_code print >> sio, " ", task_code
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论