提交 3169fbb0 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Changed the behaviour of inserting deep_copy_op in case borrow=True ( by

inserting view_op, a inplace version of deep_copy_op).
上级 5a86d812
...@@ -190,7 +190,31 @@ class DeepCopyOp(theano.gof.Op): ...@@ -190,7 +190,31 @@ class DeepCopyOp(theano.gof.Op):
else: else:
super(DeepCopyOp, self).c_code(node, name, inames, onames, sub) super(DeepCopyOp, self).c_code(node, name, inames, onames, sub)
class ViewOp(theano.gof.Op):
def __init__(self):
self.view_map={0:[0]}
def __str__(self):
return self.__class__.__name__
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def make_node(self, x):
return theano.gof.Apply(self, [x], [x.type()])
def perform( self, node, args, outs):
outs[0][0] = args[0]
deep_copy_op = DeepCopyOp() deep_copy_op = DeepCopyOp()
view_op = ViewOp()
DUPLICATE = ['DUPLICATE'] # unique id object used as a placeholder for duplicate entries DUPLICATE = ['DUPLICATE'] # unique id object used as a placeholder for duplicate entries
class Function(object): class Function(object):
...@@ -771,7 +795,10 @@ def insert_deepcopy(env, wrapped_inputs, wrapped_outputs): ...@@ -771,7 +795,10 @@ def insert_deepcopy(env, wrapped_inputs, wrapped_outputs):
# We could don't put deep copy if both outputs have borrow==True # We could don't put deep copy if both outputs have borrow==True
# and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow): # and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow):
if env.outputs[j] in views_of_output_i: if env.outputs[j] in views_of_output_i:
env.change_input('output', i, deep_copy_op(env.outputs[i])) if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
env.change_input('output',i, view_op(env.outputs[i]))
else:
env.change_input('output', i, deep_copy_op(env.outputs[i]))
copied = True copied = True
break break
...@@ -786,8 +813,21 @@ def insert_deepcopy(env, wrapped_inputs, wrapped_outputs): ...@@ -786,8 +813,21 @@ def insert_deepcopy(env, wrapped_inputs, wrapped_outputs):
continue continue
# We could don't put deep_copy_op if the input and the output have borrow==True # We could don't put deep_copy_op if the input and the output have borrow==True
if input_j in views_of_output_i: if input_j in views_of_output_i:
env.change_input('output', i, deep_copy_op(env.outputs[i])) # We don't put deep_copy_op if the input and the output have borrow==True
break if input_j in env.inputs:
j = env.inputs.index(input_j)
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
env.change_input('output',i, view_op(env.outputs[i]))
break
else:
env.change_input('output', i, deep_copy_op(env.outputs[i]))
break
elif wrapped_outputs[i].borrow:
env.change_input('output',i, view_op(env.outputs[i]))
break
else:
env.change_input('output', i, deep_copy_op(env.outputs[i]))
break
NODEFAULT = ['NODEFAULT'] NODEFAULT = ['NODEFAULT']
class FunctionMaker(object): class FunctionMaker(object):
......
...@@ -226,6 +226,7 @@ class In(SymbolicInput): ...@@ -226,6 +226,7 @@ class In(SymbolicInput):
autoname=autoname, autoname=autoname,
implicit=implicit) implicit=implicit)
self.value = value self.value = value
self.borrow = borrow
if self.implicit and value is None: if self.implicit and value is None:
raise TypeError('An implicit input must be given a default value') raise TypeError('An implicit input must be given a default value')
......
...@@ -304,11 +304,8 @@ class T_function(unittest.TestCase): ...@@ -304,11 +304,8 @@ class T_function(unittest.TestCase):
assert (out==4).all() assert (out==4).all()
out[0]=3 out[0]=3
out2 = f() out2 = f()
# Currently we don't do this optimization! assert out2 is out
# As this is a corner case that is not usefull for use assert (out2==3).all()
# We probably won't optimize it.
assert out2 is not out
assert (out2==4).all()
def test_borrow_input(self): def test_borrow_input(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论