提交 54645174 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use python any where we aren't working with ndarrays.

上级 1a4fec6e
...@@ -13,7 +13,6 @@ from six import string_types ...@@ -13,7 +13,6 @@ from six import string_types
from theano.compile.io import In from theano.compile.io import In
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc from theano.compile.pfunc import pfunc
import numpy as np
import warnings import warnings
from theano import compat from theano import compat
...@@ -286,7 +285,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -286,7 +285,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
"input.") "input.")
# compute some features of the arguments: # compute some features of the arguments:
uses_tuple = np.any([isinstance(i, (list, tuple)) for i in inputs]) uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs])
uses_updates = bool(updates) uses_updates = bool(updates)
uses_givens = bool(givens) uses_givens = bool(givens)
......
...@@ -836,9 +836,9 @@ class Function(object): ...@@ -836,9 +836,9 @@ class Function(object):
in args_share_memory[j]], in args_share_memory[j]],
[self.input_storage[k].storage[0] for k [self.input_storage[k].storage[0] for k
in args_share_memory[j]]) in args_share_memory[j]])
if np.any([(var.type is i_var.type and if any([(var.type is i_var.type and
var.type.may_share_memory(val, i_val)) var.type.may_share_memory(val, i_val))
for (var, val) in group_j]): for (var, val) in group_j]):
is_aliased = True is_aliased = True
args_share_memory[j].append(i) args_share_memory[j].append(i)
......
...@@ -948,7 +948,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -948,7 +948,7 @@ def scan_can_remove_outs(op, out_idxs):
added = False added = False
for pos, idx in enumerate(out_idxs): for pos, idx in enumerate(out_idxs):
if (out_idxs_mask[pos] and if (out_idxs_mask[pos] and
np.any([x in required_inputs for x in out_ins[idx]])): any([x in required_inputs for x in out_ins[idx]])):
# This output is required .. # This output is required ..
out_idxs_mask[pos] = 0 out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]]) required_inputs += gof.graph.inputs([op.outputs[idx]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论