提交 5222ae46 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6215 from abergeron/lol

Fix the copy for aliased function arguments
......@@ -13,7 +13,6 @@ from six import string_types
from theano.compile.io import In
from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc
import numpy as np
import warnings
from theano import compat
......@@ -286,7 +285,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
"input.")
# compute some features of the arguments:
uses_tuple = np.any([isinstance(i, (list, tuple)) for i in inputs])
uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs])
uses_updates = bool(updates)
uses_givens = bool(givens)
......
......@@ -836,9 +836,9 @@ class Function(object):
in args_share_memory[j]],
[self.input_storage[k].storage[0] for k
in args_share_memory[j]])
if np.any([(var.type is i_var.type and
var.type.may_share_memory(val, i_val))
for (var, val) in group_j]):
if any([(var.type is i_var.type and
var.type.may_share_memory(val, i_val))
for (var, val) in group_j]):
is_aliased = True
args_share_memory[j].append(i)
......@@ -847,13 +847,13 @@ class Function(object):
if not is_aliased:
args_share_memory.append([i])
# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# copy all but the first
for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy(
self.input_storage[i].storage[0])
# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# copy all but the first
for j in group[1:]:
self.input_storage[j].storage[0] = copy.copy(
self.input_storage[j].storage[0])
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
......
......@@ -948,7 +948,7 @@ def scan_can_remove_outs(op, out_idxs):
added = False
for pos, idx in enumerate(out_idxs):
if (out_idxs_mask[pos] and
np.any([x in required_inputs for x in out_ins[idx]])):
any([x in required_inputs for x in out_ins[idx]])):
# This output is required ..
out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论