提交 e67bed64 authored 作者: Frederic's avatar Frederic

add theano.scalar.get_scalar_type(dtype) to cache Scalar(dtype) objects.

This speeds up optimization.
上级 dc91b8f5
...@@ -72,11 +72,11 @@ class GpuElemwise(HideC, Elemwise): ...@@ -72,11 +72,11 @@ class GpuElemwise(HideC, Elemwise):
try: try:
inps = [make_argument(i, 'i%d' % (n,)) for n, i in inps = [make_argument(i, 'i%d' % (n,)) for n, i in
enumerate(node.inputs)] enumerate(node.inputs)]
scal_ins = [scalar.Scalar(i.dtype) for i in node.inputs] scal_ins = [scalar.get_scalar_type(i.dtype) for i in node.inputs]
outs = [make_argument(o, 'o%d' % (n,)) for n, o in outs = [make_argument(o, 'o%d' % (n,)) for n, o in
enumerate(node.outputs) if not n in self.inplace_pattern] enumerate(node.outputs) if not n in self.inplace_pattern]
scal_out = [scalar.Scalar(o.dtype) for o in node.outputs] scal_out = [scalar.get_scalar_type(o.dtype) for o in node.outputs]
fake_node = Apply(self.scalar_op, [i() for i in scal_ins], fake_node = Apply(self.scalar_op, [i() for i in scal_ins],
[o() for o in scal_out]) [o() for o in scal_out])
...@@ -99,11 +99,11 @@ class GpuElemwise(HideC, Elemwise): ...@@ -99,11 +99,11 @@ class GpuElemwise(HideC, Elemwise):
def generate_kernel(self, node, nodename): def generate_kernel(self, node, nodename):
inps = [make_argument(i, 'i%d' % (n,)) for n, i in inps = [make_argument(i, 'i%d' % (n,)) for n, i in
enumerate(node.inputs)] enumerate(node.inputs)]
scal_ins = [scalar.Scalar(i.dtype) for i in node.inputs] scal_ins = [scalar.get_scalar_type(i.dtype) for i in node.inputs]
outs = [make_argument(o, 'o%d' % (n,)) for n, o in outs = [make_argument(o, 'o%d' % (n,)) for n, o in
enumerate(node.outputs) if not n in self.inplace_pattern] enumerate(node.outputs) if not n in self.inplace_pattern]
scal_out = [scalar.Scalar(o.dtype) for o in node.outputs] scal_out = [scalar.get_scalar_type(o.dtype) for o in node.outputs]
fake_node = Apply(self.scalar_op, [i() for i in scal_ins], fake_node = Apply(self.scalar_op, [i() for i in scal_ins],
[o() for o in scal_out]) [o() for o in scal_out])
......
...@@ -69,6 +69,18 @@ def upcast(dtype, *dtypes): ...@@ -69,6 +69,18 @@ def upcast(dtype, *dtypes):
return rval return rval
def get_scalar_type(dtype):
    """
    Return a Scalar(dtype) object.

    Results are cached, one instance per dtype, to save allocation and
    run time: repeated calls with the same dtype return the same Scalar
    object. Per the commit intent, this speeds up optimization, where
    Scalar(dtype) would otherwise be constructed repeatedly.

    :param dtype: dtype name (e.g. 'float32') forwarded to Scalar(dtype=...).
    :return: the cached Scalar instance for that dtype.
    """
    # NOTE(review): assumes Scalar instances are safe to share across
    # callers (immutable once constructed) -- confirm against Scalar's
    # definition elsewhere in this file.
    if dtype not in get_scalar_type.cache:
        get_scalar_type.cache[dtype] = Scalar(dtype=dtype)
    return get_scalar_type.cache[dtype]
# Function attribute used as the module-level cache, populated lazily.
get_scalar_type.cache = {}
def as_scalar(x, name=None): def as_scalar(x, name=None):
if isinstance(x, gof.Apply): if isinstance(x, gof.Apply):
if len(x.outputs) != 1: if len(x.outputs) != 1:
...@@ -91,7 +103,7 @@ def constant(x): ...@@ -91,7 +103,7 @@ def constant(x):
# purpose typically. # purpose typically.
if hasattr(x, 'dtype'): if hasattr(x, 'dtype'):
assert x.ndim == 0 assert x.ndim == 0
return ScalarConstant(Scalar(str(x.dtype)), x) return ScalarConstant(get_scalar_type(str(x.dtype)), x)
if isinstance(x, builtin_float): if isinstance(x, builtin_float):
for dtype in ['float32', 'float64']: for dtype in ['float32', 'float64']:
x_ = theano._asarray(x, dtype=dtype) x_ = theano._asarray(x, dtype=dtype)
...@@ -99,7 +111,7 @@ def constant(x): ...@@ -99,7 +111,7 @@ def constant(x):
break break
x_ = None x_ = None
assert x_ is not None assert x_ is not None
return ScalarConstant(Scalar(str(x_.dtype)), x) return ScalarConstant(get_scalar_type(str(x_.dtype)), x)
if isinstance(x, builtin_int): if isinstance(x, builtin_int):
for dtype in ['int8', 'int16', 'int32', 'int64']: for dtype in ['int8', 'int16', 'int32', 'int64']:
x_ = theano._asarray(x, dtype=dtype) x_ = theano._asarray(x, dtype=dtype)
...@@ -107,7 +119,7 @@ def constant(x): ...@@ -107,7 +119,7 @@ def constant(x):
break break
x_ = None x_ = None
assert x_ is not None assert x_ is not None
return ScalarConstant(Scalar(str(x_.dtype)), x) return ScalarConstant(get_scalar_type(str(x_.dtype)), x)
if isinstance(x, builtin_complex): if isinstance(x, builtin_complex):
#TODO: We have added the complex type, so this should be tested #TODO: We have added the complex type, so this should be tested
raise NotImplementedError() raise NotImplementedError()
...@@ -457,18 +469,18 @@ theano.compile.register_view_op_c_code( ...@@ -457,18 +469,18 @@ theano.compile.register_view_op_c_code(
1) 1)
int8 = Scalar('int8') int8 = get_scalar_type('int8')
int16 = Scalar('int16') int16 = get_scalar_type('int16')
int32 = Scalar('int32') int32 = get_scalar_type('int32')
int64 = Scalar('int64') int64 = get_scalar_type('int64')
uint8 = Scalar('uint8') uint8 = get_scalar_type('uint8')
uint16 = Scalar('uint16') uint16 = get_scalar_type('uint16')
uint32 = Scalar('uint32') uint32 = get_scalar_type('uint32')
uint64 = Scalar('uint64') uint64 = get_scalar_type('uint64')
float32 = Scalar('float32') float32 = get_scalar_type('float32')
float64 = Scalar('float64') float64 = get_scalar_type('float64')
complex64 = Scalar('complex64') complex64 = get_scalar_type('complex64')
complex128 = Scalar('complex128') complex128 = get_scalar_type('complex128')
int_types = int8, int16, int32, int64 int_types = int8, int16, int32, int64
uint_types = uint8, uint16, uint32, uint64 uint_types = uint8, uint16, uint32, uint64
...@@ -584,7 +596,7 @@ class _scalar_py_operators: ...@@ -584,7 +596,7 @@ class _scalar_py_operators:
# The second is needed for Elemwise ops to work right # The second is needed for Elemwise ops to work right
if dtype is None: if dtype is None:
dtype = str(self.type.dtype) dtype = str(self.type.dtype)
return second(self, ScalarConstant(Scalar(dtype), 0)) return second(self, ScalarConstant(get_scalar_type(dtype), 0))
def astype(self, dtype): def astype(self, dtype):
return cast(self, dtype) return cast(self, dtype)
...@@ -628,7 +640,8 @@ complexs128 = _multi(complex128) ...@@ -628,7 +640,8 @@ complexs128 = _multi(complex128)
# necessary to use this same mechanism in other places as well in the future. # necessary to use this same mechanism in other places as well in the future.
class upcast_out(object): class upcast_out(object):
def __new__(self, *types): def __new__(self, *types):
return Scalar(dtype=Scalar.upcast(*types)), dtype = Scalar.upcast(*types)
return get_scalar_type(dtype),
class upgrade_to_float(object): class upgrade_to_float(object):
...@@ -644,7 +657,7 @@ class upgrade_to_float(object): ...@@ -644,7 +657,7 @@ class upgrade_to_float(object):
uint16: float32, uint16: float32,
uint32: float64, uint32: float64,
uint64: float64} uint64: float64}
return Scalar(Scalar.upcast(*[conv.get(type, type) return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
for type in types])), for type in types])),
...@@ -656,7 +669,7 @@ class same_out(object): ...@@ -656,7 +669,7 @@ class same_out(object):
def upcast_out_no_complex(*types): def upcast_out_no_complex(*types):
if any([type in complex_types for type in types]): if any([type in complex_types for type in types]):
raise TypeError('complex type are not supported') raise TypeError('complex type are not supported')
return Scalar(dtype=Scalar.upcast(*types)), return get_scalar_type(dtype=Scalar.upcast(*types)),
def same_out_float_only(type): def same_out_float_only(type):
...@@ -1455,7 +1468,7 @@ def div_proxy(x, y): ...@@ -1455,7 +1468,7 @@ def div_proxy(x, y):
class TrueDiv(BinaryScalarOp): class TrueDiv(BinaryScalarOp):
def output_types(self, types): def output_types(self, types):
if all(t in discrete_types for t in types): if all(t in discrete_types for t in types):
return [Scalar(config.floatX)] return [get_scalar_type(config.floatX)]
else: else:
return super(TrueDiv, self).output_types(types) return super(TrueDiv, self).output_types(types)
......
...@@ -59,7 +59,7 @@ def safe_new(x, tag='', dtype=None): ...@@ -59,7 +59,7 @@ def safe_new(x, tag='', dtype=None):
# making the pushout optimization fail # making the pushout optimization fail
elif isinstance(x, scalar.ScalarVariable): elif isinstance(x, scalar.ScalarVariable):
if dtype: if dtype:
nw_x = scalar.Scalar(dtype=dtype)() nw_x = scalar.get_scalar_type(dtype=dtype)()
else: else:
nw_x = x.type() nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
......
...@@ -991,7 +991,7 @@ class ScalarFromTensor(Op): ...@@ -991,7 +991,7 @@ class ScalarFromTensor(Op):
assert t.type.broadcastable == () assert t.type.broadcastable == ()
return Apply(self, return Apply(self,
[t], [t],
[scal.Scalar(dtype=t.type.dtype).make_variable()]) [scal.get_scalar_type(dtype=t.type.dtype).make_variable()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
s, = inp s, = inp
......
...@@ -8,7 +8,7 @@ import theano ...@@ -8,7 +8,7 @@ import theano
from theano import gof from theano import gof
from theano.gof import Apply, Op from theano.gof import Apply, Op
from theano import scalar from theano import scalar
from theano.scalar import Scalar from theano.scalar import Scalar, get_scalar_type
from theano.printing import pprint from theano.printing import pprint
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano.tensor.utils import hash_from_dict from theano.tensor.utils import hash_from_dict
...@@ -515,7 +515,7 @@ class Elemwise(Op): ...@@ -515,7 +515,7 @@ class Elemwise(Op):
""" """
inputs = map(as_tensor_variable, inputs) inputs = map(as_tensor_variable, inputs)
shadow = self.scalar_op.make_node( shadow = self.scalar_op.make_node(
*[Scalar(dtype=i.type.dtype)() for i in inputs]) *[get_scalar_type(dtype=i.type.dtype)() for i in inputs])
target_length = max([input.type.ndim for input in inputs]) target_length = max([input.type.ndim for input in inputs])
...@@ -718,7 +718,7 @@ class Elemwise(Op): ...@@ -718,7 +718,7 @@ class Elemwise(Op):
def as_scalar(t): def as_scalar(t):
if isinstance(t.type, (NullType, DisconnectedType)): if isinstance(t.type, (NullType, DisconnectedType)):
return t return t
return Scalar(t.type.dtype)() return get_scalar_type(t.type.dtype)()
scalar_inputs = map(as_scalar, inputs) scalar_inputs = map(as_scalar, inputs)
scalar_ograds = map(as_scalar, ograds) scalar_ograds = map(as_scalar, ograds)
...@@ -1039,9 +1039,9 @@ class Elemwise(Op): ...@@ -1039,9 +1039,9 @@ class Elemwise(Op):
# We generate the C code of the inner loop using the scalar op # We generate the C code of the inner loop using the scalar op
task_code = self.scalar_op.c_code( task_code = self.scalar_op.c_code(
Apply(self.scalar_op, Apply(self.scalar_op,
[Scalar(dtype=input.type.dtype)() [get_scalar_type(dtype=input.type.dtype)()
for input in node.inputs], for input in node.inputs],
[Scalar(dtype=output.type.dtype)() [get_scalar_type(dtype=output.type.dtype)()
for output in node.outputs]), for output in node.outputs]),
nodename + '_scalar_', nodename + '_scalar_',
["%s_i" % s for s in _inames], ["%s_i" % s for s in _inames],
...@@ -1161,11 +1161,11 @@ class Elemwise(Op): ...@@ -1161,11 +1161,11 @@ class Elemwise(Op):
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op, scalar_node = Apply(self.scalar_op,
[Scalar(dtype=input.type.dtype)() for input in node.inputs], [get_scalar_type(dtype=input.type.dtype)() for input in node.inputs],
[Scalar(dtype=output.type.dtype)() for output in node.outputs]) [get_scalar_type(dtype=output.type.dtype)() for output in node.outputs])
version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) version.append(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs: for i in node.inputs + node.outputs:
version.append(Scalar(dtype=i.type.dtype).c_code_cache_version()) version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version())
if all(version): if all(version):
return tuple(version) return tuple(version)
else: else:
...@@ -1531,9 +1531,9 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1531,9 +1531,9 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
task1_code = self.scalar_op.c_code( task1_code = self.scalar_op.c_code(
Apply( Apply(
self.scalar_op, self.scalar_op,
[Scalar(dtype=input.type.dtype)() [get_scalar_type(dtype=input.type.dtype)()
for input in (node.inputs * 2)], for input in (node.inputs * 2)],
[Scalar(dtype=output.type.dtype)() [get_scalar_type(dtype=output.type.dtype)()
for input in node.outputs]), for input in node.outputs]),
None, None,
["%s_i" % aname, "%s_i" % inames[0]], ["%s_i" % aname, "%s_i" % inames[0]],
...@@ -1583,11 +1583,11 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1583,11 +1583,11 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op, scalar_node = Apply(self.scalar_op,
[Scalar(dtype=input.type.dtype)() for input in node.inputs], [get_scalar_type(dtype=input.type.dtype)() for input in node.inputs],
[Scalar(dtype=output.type.dtype)() for output in node.outputs]) [get_scalar_type(dtype=output.type.dtype)() for output in node.outputs])
version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) version.append(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs: for i in node.inputs + node.outputs:
version.append(Scalar(dtype=i.type.dtype).c_code_cache_version()) version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version())
if all(version): if all(version):
return tuple(version) return tuple(version)
else: else:
......
...@@ -3583,7 +3583,7 @@ def local_pow_specialize_device(node): ...@@ -3583,7 +3583,7 @@ def local_pow_specialize_device(node):
# 512 is too small for the cpu and too big for some gpu! # 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512: if abs(y) == int(abs(y)) and abs(y) <= 512:
pow2 = [xsym] pow2 = [xsym]
pow2_scal = [theano.scalar.Scalar(xsym.dtype)()] pow2_scal = [theano.scalar.get_scalar_type(xsym.dtype)()]
y_to_do = abs(y) y_to_do = abs(y)
for i in xrange(int(numpy.log2(y_to_do))): for i in xrange(int(numpy.log2(y_to_do))):
pow2.append(T.sqr(pow2[i])) pow2.append(T.sqr(pow2[i]))
...@@ -4638,7 +4638,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4638,7 +4638,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif ii in tmp_input: elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else: else:
tmp = scalar.Scalar(ii.dtype).make_variable() tmp = scalar.get_scalar_type(ii.dtype).make_variable()
try: try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0] tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0]
except AttributeError: except AttributeError:
...@@ -4692,7 +4692,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4692,7 +4692,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
if inputs.count(i) == node.inputs.count(i): if inputs.count(i) == node.inputs.count(i):
s = s_inputs[inputs.index(i)] s = s_inputs[inputs.index(i)]
else: else:
s = scalar.Scalar(i.dtype).make_variable() s = scalar.get_scalar_type(i.dtype).make_variable()
try: try:
if theano.config.compute_test_value != 'off': if theano.config.compute_test_value != 'off':
v = gof.op.get_test_value(i) v = gof.op.get_test_value(i)
......
...@@ -318,11 +318,11 @@ class Subtensor(Op): ...@@ -318,11 +318,11 @@ class Subtensor(Op):
if (isinstance(entry, gof.Variable) if (isinstance(entry, gof.Variable)
and entry.type in tensor_types and entry.type in tensor_types
and numpy.all(entry.type.broadcastable)): and numpy.all(entry.type.broadcastable)):
return scal.Scalar(entry.type.dtype) return scal.get_scalar_type(entry.type.dtype)
elif (isinstance(entry, gof.Type) elif (isinstance(entry, gof.Type)
and entry in tensor_types and entry in tensor_types
and numpy.all(entry.broadcastable)): and numpy.all(entry.broadcastable)):
return scal.Scalar(entry.dtype) return scal.get_scalar_type(entry.dtype)
elif slice_ok and isinstance(entry, slice): elif slice_ok and isinstance(entry, slice):
a = entry.start a = entry.start
b = entry.stop b = entry.stop
......
...@@ -240,7 +240,7 @@ class TensorType(Type): ...@@ -240,7 +240,7 @@ class TensorType(Type):
% (self.__class__.__name__, self.dtype)) % (self.__class__.__name__, self.dtype))
def to_scalar_type(self): def to_scalar_type(self):
return scal.Scalar(dtype=self.dtype) return scal.get_scalar_type(dtype=self.dtype)
def __eq__(self, other): def __eq__(self, other):
"""Compare True iff other is the same kind of TensorType""" """Compare True iff other is the same kind of TensorType"""
...@@ -538,23 +538,23 @@ class TensorType(Type): ...@@ -538,23 +538,23 @@ class TensorType(Type):
def c_headers(self): def c_headers(self):
"""Override `CLinkerObject.c_headers` """ """Override `CLinkerObject.c_headers` """
return scal.Scalar(self.dtype).c_headers() return scal.get_scalar_type(self.dtype).c_headers()
def c_libraries(self): def c_libraries(self):
return scal.Scalar(self.dtype).c_libraries() return scal.get_scalar_type(self.dtype).c_libraries()
def c_compile_args(self): def c_compile_args(self):
return scal.Scalar(self.dtype).c_compile_args() return scal.get_scalar_type(self.dtype).c_compile_args()
def c_support_code(self): def c_support_code(self):
"""Override `CLinkerObject.c_support_code` """ """Override `CLinkerObject.c_support_code` """
return scal.Scalar(self.dtype).c_support_code() return scal.get_scalar_type(self.dtype).c_support_code()
def c_init_code(self): def c_init_code(self):
return scal.Scalar(self.dtype).c_init_code() return scal.get_scalar_type(self.dtype).c_init_code()
def c_code_cache_version(self): def c_code_cache_version(self):
scalar_version = scal.Scalar(self.dtype).c_code_cache_version() scalar_version = scal.get_scalar_type(self.dtype).c_code_cache_version()
if scalar_version: if scalar_version:
return (11,) + scalar_version return (11,) + scalar_version
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论