提交 efb4786e authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5032 from nouiz/simplify

Speed up the canonizer for big list of num/denum
...@@ -6,7 +6,6 @@ types that it can raise. ...@@ -6,7 +6,6 @@ types that it can raise.
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from collections import OrderedDict from collections import OrderedDict
import sys
import time import time
import traceback import traceback
...@@ -260,7 +259,7 @@ class FunctionGraph(utils.object2): ...@@ -260,7 +259,7 @@ class FunctionGraph(utils.object2):
""" """
return r.clients return r.clients
def __add_clients__(self, r, new_clients): def __add_client__(self, r, new_client):
""" """
Updates the list of clients of r with new_clients. Updates the list of clients of r with new_clients.
...@@ -268,21 +267,19 @@ class FunctionGraph(utils.object2): ...@@ -268,21 +267,19 @@ class FunctionGraph(utils.object2):
---------- ----------
r r
Variable. Variable.
new_clients new_client
List of (node, i) pairs such that node.inputs[i] is r. (node, i) pair such that node.inputs[i] is r.
""" """
if set(r.clients).intersection(set(new_clients)): # No need to do the assert as it is always True. The logic
print('ERROR: clients intersect!', file=sys.stderr) # that calls __add_client__ is valid. When the client list is
print(' RCLIENTS of', r, [(n, i, type(n), id(n)) # long, the check is time consuming, so we don't enable it by
for n, i in r.clients], file=sys.stderr) # default.
print(' NCLIENTS of', r, [(n, i, type(n), id(n)) # assert not new_client in r.clients
for n, i in new_clients], file=sys.stderr) r.clients.append(new_client)
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, def __remove_client__(self, r, client_to_remove,
prune=True, reason=None): prune=True, reason=None):
""" """
Removes all from the clients list of r. Removes all from the clients list of r.
...@@ -296,8 +293,8 @@ class FunctionGraph(utils.object2): ...@@ -296,8 +293,8 @@ class FunctionGraph(utils.object2):
---------- ----------
r : Variable r : Variable
The clients of r will be removed. The clients of r will be removed.
clients_to_remove : List of (op, i) pairs client_to_remove : (op, i) pair
List of (op, i) pairs such that node.inputs[i] is not r anymore. (op, i) pair such that node.inputs[i] is not r anymore.
prune : bool prune : bool
If prune is True, it remove r from this fgraph if it don't If prune is True, it remove r from this fgraph if it don't
have clients left. have clients left.
...@@ -311,9 +308,11 @@ class FunctionGraph(utils.object2): ...@@ -311,9 +308,11 @@ class FunctionGraph(utils.object2):
clients_to_remove and prune=True will remove r. clients_to_remove and prune=True will remove r.
""" """
for entry in clients_to_remove: if client_to_remove:
r.clients.remove(entry) r.clients.remove(client_to_remove)
assert entry not in r.clients # an op,i pair should be unique # entry should be unique in r. No need to assert it as it is
# already asserted in __add_client__.
# assert entry not in r.clients
if r.clients: if r.clients:
return False return False
if not prune: if not prune:
...@@ -333,8 +332,8 @@ class FunctionGraph(utils.object2): ...@@ -333,8 +332,8 @@ class FunctionGraph(utils.object2):
self.execute_callbacks('on_prune', apply_node, reason) self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(apply_node.inputs): for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)], self.__remove_client__(input, (apply_node, i),
reason=reason) reason=reason)
# variable should not have any clients. # variable should not have any clients.
# assert not variable.clients # assert not variable.clients
...@@ -431,7 +430,7 @@ class FunctionGraph(utils.object2): ...@@ -431,7 +430,7 @@ class FunctionGraph(utils.object2):
if input not in self.variables: if input not in self.variables:
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
self.__add_clients__(input, [(node, i)]) self.__add_client__(input, (node, i))
assert node.fgraph is self assert node.fgraph is self
self.execute_callbacks('on_import', node, reason) self.execute_callbacks('on_import', node, reason)
...@@ -470,15 +469,15 @@ class FunctionGraph(utils.object2): ...@@ -470,15 +469,15 @@ class FunctionGraph(utils.object2):
return return
self.__import_r__(new_r, reason=reason) self.__import_r__(new_r, reason=reason)
self.__add_clients__(new_r, [(node, i)]) self.__add_client__(new_r, (node, i))
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_client__(r, (node, i), False)
# Precondition: the substitution is semantically valid # Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the # However it may introduce cycles to the graph, in which case the
# transaction will be reverted later. # transaction will be reverted later.
self.execute_callbacks('on_change_input', node, i, self.execute_callbacks('on_change_input', node, i,
r, new_r, reason=reason) r, new_r, reason=reason)
if prune: if prune:
self.__remove_clients__(r, [], True, reason=reason) self.__remove_client__(r, None, True, reason=reason)
# replace # # replace #
def replace(self, r, new_r, reason=None, verbose=None): def replace(self, r, new_r, reason=None, verbose=None):
......
...@@ -29,18 +29,21 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d, ...@@ -29,18 +29,21 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
from .type import (GpuArrayType, GpuArrayConstant, get_context, from .type import (GpuArrayType, GpuArrayConstant, get_context,
ContextNotDefined) ContextNotDefined, move_to_gpu)
from .basic_ops import (as_gpuarray_variable, infer_context_name, from .basic_ops import (as_gpuarray_variable, infer_context_name,
host_from_gpu, GpuToGpu, host_from_gpu, GpuToGpu,
HostFromGpu, GpuFromHost, HostFromGpu, GpuFromHost,
GpuSplit, GpuContiguous, gpu_contiguous, GpuSplit, GpuContiguous, gpu_contiguous,
GpuAlloc, GpuAllocEmpty, GpuReshape, GpuAlloc, GpuAllocEmpty, GpuReshape,
GpuEye, gpu_join, GpuJoin, gpu_alloc_empty, gpu_alloc, gpu_from_host) GpuEye, gpu_join, GpuJoin, gpu_alloc_empty,
gpu_alloc, gpu_from_host)
from .blas import (gpu_dot22, GpuGemm, GpuGer, GpuGemmBatch, from .blas import (gpu_dot22, GpuGemm, GpuGer, GpuGemmBatch,
gpugemm_no_inplace, gpugemm_inplace, gpugemmbatch_no_inplace, gpugemm_no_inplace, gpugemm_inplace,
gpugemmbatch_no_inplace,
gpugemv_no_inplace, gpugemv_inplace) gpugemv_no_inplace, gpugemv_inplace)
from .blocksparse import (GpuSparseBlockGemv, GpuSparseBlockOuter, from .blocksparse import (GpuSparseBlockGemv, GpuSparseBlockOuter,
gpu_sparse_block_outer, gpu_sparse_block_outer_inplace, gpu_sparse_block_outer,
gpu_sparse_block_outer_inplace,
gpu_sparse_block_gemv, gpu_sparse_block_gemv_inplace) gpu_sparse_block_gemv, gpu_sparse_block_gemv_inplace)
from .nnet import (gpu_crossentropy_softmax_1hot_with_bias_dx, from .nnet import (gpu_crossentropy_softmax_1hot_with_bias_dx,
gpu_crossentropy_softmax_argmax_1hot_with_bias, gpu_crossentropy_softmax_argmax_1hot_with_bias,
...@@ -239,9 +242,8 @@ class InputToGpuOptimizer(Optimizer): ...@@ -239,9 +242,8 @@ class InputToGpuOptimizer(Optimizer):
target = getattr(input.tag, 'target', None) target = getattr(input.tag, 'target', None)
if target == 'cpu': if target == 'cpu':
continue continue
# Do not move *int* scalar to the GPU.
if (isinstance(input.type, tensor.TensorType) and if (isinstance(input.type, tensor.TensorType) and
input.ndim == 0 and 'int' in input.dtype): not move_to_gpu(input)):
continue continue
try: try:
...@@ -297,10 +299,7 @@ class GraphToGPU(Optimizer): ...@@ -297,10 +299,7 @@ class GraphToGPU(Optimizer):
# Iterating through inputs of graph # Iterating through inputs of graph
target = infer_context_name(*fgraph.inputs) target = infer_context_name(*fgraph.inputs)
for i in fgraph.inputs: for i in fgraph.inputs:
# Do not move *int* scalar to the GPU. if isinstance(i.type, tensor.TensorType) and move_to_gpu(i):
if (isinstance(i.type, tensor.TensorType) and
(i.ndim > 0 or 'int' not in i.dtype) and
"complex" not in i.dtype):
mapping[i] = i.transfer(getattr(i.tag, 'target', target)) mapping[i] = i.transfer(getattr(i.tag, 'target', target))
else: else:
mapping[i] = i mapping[i] = i
......
...@@ -22,6 +22,26 @@ except ImportError: ...@@ -22,6 +22,26 @@ except ImportError:
_context_reg = {} _context_reg = {}
def move_to_gpu(data):
    """
    Decide whether `data` should be moved to the GPU.

    Two kinds of data are kept off the GPU: anything whose dtype is
    listed in ``tensor.basic.complex_dtypes``, and 0-d (scalar) data
    whose dtype is listed in ``tensor.basic.discrete_dtypes``.

    Parameters
    ----------
    data : numpy.ndarray or TensorVariable
        Any object exposing ``dtype`` and ``ndim`` attributes.

    Returns
    -------
    bool
        False for the two excluded cases above, True otherwise.

    """
    dtype_name = str(data.dtype)

    # Complex dtypes are not supported on the GPU.
    if dtype_name in tensor.basic.complex_dtypes:
        return False

    # Scalar integers are deliberately left on the host.
    is_discrete_scalar = (data.ndim == 0 and
                          dtype_name in tensor.basic.discrete_dtypes)
    return not is_discrete_scalar
class ContextNotDefined(ValueError): class ContextNotDefined(ValueError):
pass pass
...@@ -561,16 +581,22 @@ class GpuArraySharedVariable(_operators, SharedVariable): ...@@ -561,16 +581,22 @@ class GpuArraySharedVariable(_operators, SharedVariable):
GpuArrayType.SharedVariable = GpuArraySharedVariable GpuArrayType.SharedVariable = GpuArraySharedVariable
notset = object()
def gpuarray_shared_constructor(value, name=None, strict=False, def gpuarray_shared_constructor(value, name=None, strict=False,
allow_downcast=None, borrow=False, allow_downcast=None, borrow=False,
broadcastable=None, target=None): broadcastable=None, target=notset):
""" """
SharedVariable constructor for GpuArrayType. SharedVariable constructor for GpuArrayType.
See :func:`theano.shared`. See :func:`theano.shared`.
:target: default None
The device target. Because None is itself a valid target, the
signature uses a separate ``notset`` sentinel as the default so
that an omitted target can be told apart from an explicit None.
""" """
if target == 'gpu' or target == 'cpu': if target == 'gpu' or target == 'cpu':
raise TypeError('not for me') raise TypeError('not for me')
...@@ -578,6 +604,10 @@ def gpuarray_shared_constructor(value, name=None, strict=False, ...@@ -578,6 +604,10 @@ def gpuarray_shared_constructor(value, name=None, strict=False,
if not isinstance(value, (numpy.ndarray, pygpu.gpuarray.GpuArray)): if not isinstance(value, (numpy.ndarray, pygpu.gpuarray.GpuArray)):
raise TypeError('ndarray or GpuArray required') raise TypeError('ndarray or GpuArray required')
if target is notset:
target = None
if not move_to_gpu(value):
raise TypeError('We do not move that data by default to the GPU')
try: try:
get_context(target) get_context(target)
except ContextNotDefined: except ContextNotDefined:
......
...@@ -4751,13 +4751,17 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4751,13 +4751,17 @@ class Canonizer(gof.LocalOptimizer):
numeric constant. If v is a plain Variable, returns None. numeric constant. If v is a plain Variable, returns None.
""" """
if isinstance(v, Variable): if isinstance(v, Constant):
try: if getattr(v.tag, 'unique_value', None) is not None:
# As the constant folding is in the canonicalize phase, data = v.tag.unique_value
# We don't need to check all the graph each time. else:
return get_scalar_constant_value(v, only_process_constants=True) data = v.data
except NotScalarConstantError: if data.ndim == 0:
return data
else:
return None return None
elif isinstance(v, Variable):
return None
else: else:
return v return v
...@@ -4790,10 +4794,25 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4790,10 +4794,25 @@ class Canonizer(gof.LocalOptimizer):
| [a, b], [c, d] -> [a, b], [c, d] | [a, b], [c, d] -> [a, b], [c, d]
""" """
for v in list(num): ln = len(num)
if v in denum: ld = len(denum)
num.remove(v) if (ld > 2 and ln > 2):
denum.remove(v) # Faster version for "big" inputs.
while True:
s = set(num)
# Inputs can appear multiple times
redo = len(s) != len(num)
inter = s.intersection(denum)
for v in inter:
num.remove(v)
denum.remove(v)
if not redo or not inter:
break
else:
for v in list(num):
if v in denum:
num.remove(v)
denum.remove(v)
return num, denum return num, denum
def simplify_constants(self, orig_num, orig_denum, out_type=None): def simplify_constants(self, orig_num, orig_denum, out_type=None):
...@@ -4815,9 +4834,8 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4815,9 +4834,8 @@ class Canonizer(gof.LocalOptimizer):
| [x, 2, y], [z, 2] -> [x, y], [z] | [x, 2, y], [z, 2] -> [x, y], [z]
""" """
# Lists representing the numerator and denumerator # Lists representing the numerator and denumerator
num, denum = list(orig_num), list(orig_denum) num, denum = [], []
# Lists representing the *constant* elements of num and denum # Lists representing the *constant* elements of num and denum
numct, denumct = [], [] numct, denumct = [], []
...@@ -4826,15 +4844,16 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4826,15 +4844,16 @@ class Canonizer(gof.LocalOptimizer):
ct = self.get_constant(v) ct = self.get_constant(v)
if ct is not None: if ct is not None:
# We found a constant in the numerator! # We found a constant in the numerator!
# We remove it from num
num.remove(v)
# We add it to numct # We add it to numct
numct.append(ct) numct.append(ct)
else:
num.append(v)
for v in orig_denum: for v in orig_denum:
ct = self.get_constant(v) ct = self.get_constant(v)
if ct is not None: if ct is not None:
denum.remove(v)
denumct.append(ct) denumct.append(ct)
else:
denum.append(v)
if self.use_reciprocal or num: if self.use_reciprocal or num:
# This will calculate either: # This will calculate either:
......
...@@ -89,16 +89,6 @@ def test_gc_never_pickles_temporaries(): ...@@ -89,16 +89,6 @@ def test_gc_never_pickles_temporaries():
# assert that f() didn't cause the function to grow # assert that f() didn't cause the function to grow
# allow_gc should leave the function un-changed by calling # allow_gc should leave the function un-changed by calling
if len_pre_f != len_post_f:
for i in range(len_pre_f//100):
p1 = pre_f[i*100:(i+1)*100]
p2 = post_f[i*100:(i+1)*100]
if p1 != p2:
print(i)
print("p1")
print(p1)
print("p2")
print(p2)
assert len_pre_f == len_post_f, (len_pre_f, len_post_f) assert len_pre_f == len_post_f, (len_pre_f, len_post_f)
# assert that g() didn't cause g to grow because temporaries # assert that g() didn't cause g to grow because temporaries
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论