提交 ede0a42b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Refractored the aliasing test by adding a method like `may_share_memory` to

the Theano type
上级 ae9825c5
...@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en" ...@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en"
import copy_reg import copy_reg
import cPickle import cPickle
import itertools
import sys, time, copy import sys, time, copy
...@@ -537,19 +538,24 @@ class Function(object): ...@@ -537,19 +538,24 @@ class Function(object):
## Collect aliased inputs among the storage space ## Collect aliased inputs among the storage space
args_share_memory = [] args_share_memory = []
for i in xrange(len(self.input_storage)): for i in xrange(len(self.input_storage)):
if isinstance(self.input_storage[i].storage[0], i_var = self.maker.inputs[i].variable
numpy.ndarray): i_val = self.input_storage[i].storage[0]
if hasattr( i_var.type, 'may_share_memory'):
is_aliased = False is_aliased = False
for j in xrange(len(args_share_memory)): for j in xrange(len(args_share_memory)):
for k in args_share_memory[j]:
if numpy.may_share_memory( group_j = itertools.izip(
self.input_storage[i].storage[0] , [self.maker.inputs[k].variable for k
self.input_storage[k].storage[0]): in args_share_memory[j]],
is_aliased = True [self.input_storage[k].storage[0] for k
args_share_memory[j].append(i) in args_share_memory[j]])
break if numpy.any([ (var.type is i_var.type and
if is_aliased: var.type.may_share_memory(val,i_val)
break ) for (var,val) in group_j]):
is_aliased = True
args_share_memory[j].append(i)
break
if not is_aliased: if not is_aliased:
args_share_memory.append([i]) args_share_memory.append([i])
......
...@@ -174,6 +174,12 @@ class SparseType(gof.Type): ...@@ -174,6 +174,12 @@ class SparseType(gof.Type):
raise NotImplementedError() raise NotImplementedError()
return sp return sp
@staticmethod
def may_share_memory(a,b):
# This is Fred suggestion for a quick and dirty way of checking
# aliasing .. this can potentially be further refined (ticket #374)
return a is b
def make_variable(self, name = None): def make_variable(self, name = None):
return SparseVariable(self, name = name) return SparseVariable(self, name = name)
......
...@@ -472,6 +472,10 @@ class TensorType(Type): ...@@ -472,6 +472,10 @@ class TensorType(Type):
return type(self) == type(other) and other.dtype == self.dtype \ return type(self) == type(other) and other.dtype == self.dtype \
and other.broadcastable == self.broadcastable and other.broadcastable == self.broadcastable
@staticmethod
def may_share_memory(a,b):
return numpy.may_share_memory(a,b)
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
#TODO: check to see if the dtype and shapes must match #TODO: check to see if the dtype and shapes must match
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论