提交 5222ae46 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6215 from abergeron/lol

Fix the copy for aliased function arguments
...@@ -13,7 +13,6 @@ from six import string_types ...@@ -13,7 +13,6 @@ from six import string_types
from theano.compile.io import In from theano.compile.io import In
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc from theano.compile.pfunc import pfunc
import numpy as np
import warnings import warnings
from theano import compat from theano import compat
...@@ -286,7 +285,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -286,7 +285,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
"input.") "input.")
# compute some features of the arguments: # compute some features of the arguments:
uses_tuple = np.any([isinstance(i, (list, tuple)) for i in inputs]) uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs])
uses_updates = bool(updates) uses_updates = bool(updates)
uses_givens = bool(givens) uses_givens = bool(givens)
......
...@@ -836,9 +836,9 @@ class Function(object): ...@@ -836,9 +836,9 @@ class Function(object):
in args_share_memory[j]], in args_share_memory[j]],
[self.input_storage[k].storage[0] for k [self.input_storage[k].storage[0] for k
in args_share_memory[j]]) in args_share_memory[j]])
if np.any([(var.type is i_var.type and if any([(var.type is i_var.type and
var.type.may_share_memory(val, i_val)) var.type.may_share_memory(val, i_val))
for (var, val) in group_j]): for (var, val) in group_j]):
is_aliased = True is_aliased = True
args_share_memory[j].append(i) args_share_memory[j].append(i)
...@@ -847,13 +847,13 @@ class Function(object): ...@@ -847,13 +847,13 @@ class Function(object):
if not is_aliased: if not is_aliased:
args_share_memory.append([i]) args_share_memory.append([i])
# Check for groups of more than one argument that share memory # Check for groups of more than one argument that share memory
for group in args_share_memory: for group in args_share_memory:
if len(group) > 1: if len(group) > 1:
# copy all but the first # copy all but the first
for idx in group[1:]: for j in group[1:]:
self.input_storage[i].storage[0] = copy.copy( self.input_storage[j].storage[0] = copy.copy(
self.input_storage[i].storage[0]) self.input_storage[j].storage[0])
# Check if inputs are missing, or if inputs were set more than once, or # Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit. # if we tried to provide inputs that are supposed to be implicit.
......
...@@ -948,7 +948,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -948,7 +948,7 @@ def scan_can_remove_outs(op, out_idxs):
added = False added = False
for pos, idx in enumerate(out_idxs): for pos, idx in enumerate(out_idxs):
if (out_idxs_mask[pos] and if (out_idxs_mask[pos] and
np.any([x in required_inputs for x in out_ins[idx]])): any([x in required_inputs for x in out_ins[idx]])):
# This output is required .. # This output is required ..
out_idxs_mask[pos] = 0 out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]]) required_inputs += gof.graph.inputs([op.outputs[idx]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论