提交 d4e4f6b3 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/compile/function_module.py

上级 e74a919e
...@@ -12,7 +12,7 @@ import six.moves.cPickle as pickle ...@@ -12,7 +12,7 @@ import six.moves.cPickle as pickle
from itertools import chain from itertools import chain
import time import time
import warnings import warnings
import numpy import numpy as np
import theano import theano
from theano import config, gof from theano import config, gof
...@@ -829,7 +829,7 @@ class Function(object): ...@@ -829,7 +829,7 @@ class Function(object):
in args_share_memory[j]], in args_share_memory[j]],
[self.input_storage[k].storage[0] for k [self.input_storage[k].storage[0] for k
in args_share_memory[j]]) in args_share_memory[j]])
if numpy.any([(var.type is i_var.type and if np.any([(var.type is i_var.type and
var.type.may_share_memory(val, i_val)) var.type.may_share_memory(val, i_val))
for (var, val) in group_j]): for (var, val) in group_j]):
...@@ -1019,9 +1019,9 @@ def _pickle_Function(f): ...@@ -1019,9 +1019,9 @@ def _pickle_Function(f):
all_data = input_storage + inputs_data all_data = input_storage + inputs_data
for i, d_i in enumerate(all_data): for i, d_i in enumerate(all_data):
for j, d_j in enumerate(all_data): for j, d_j in enumerate(all_data):
if ((i < j) and isinstance(d_i, numpy.ndarray) and if ((i < j) and isinstance(d_i, np.ndarray) and
isinstance(d_j, numpy.ndarray)): isinstance(d_j, np.ndarray)):
if numpy.may_share_memory(d_i, d_j): if np.may_share_memory(d_i, d_j):
if f.pickle_aliased_memory_strategy == 'warn': if f.pickle_aliased_memory_strategy == 'warn':
_logger.warning('aliased relationship between ' _logger.warning('aliased relationship between '
'Function arguments %s, %s ' 'Function arguments %s, %s '
...@@ -1041,7 +1041,7 @@ def _constructor_Function(maker, input_storage, inputs_data): ...@@ -1041,7 +1041,7 @@ def _constructor_Function(maker, input_storage, inputs_data):
assert len(f.input_storage) == len(inputs_data) assert len(f.input_storage) == len(inputs_data)
for container, x in zip(f.input_storage, inputs_data): for container, x in zip(f.input_storage, inputs_data):
assert (container.data is x) or \ assert (container.data is x) or \
(isinstance(x, numpy.ndarray) and (container.data == x).all()) or \ (isinstance(x, np.ndarray) and (container.data == x).all()) or \
(container.data == x) (container.data == x)
return f return f
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论