提交 0cbdf291 authored 作者: Frederic's avatar Frederic

Add Type.get_size() fct to help during profiling.

This add the size of the RandomType.
上级 25a45105
......@@ -603,12 +603,8 @@ class ProfileStats(object):
sh = self.variable_shape[out]
if isinstance(out.type, theano.sparse.SparseType):
v = "Sparse"
elif isinstance(out.type,
(theano.tensor.TensorType,
theano.sandbox.cuda.CudaNdarrayType)):
v = numpy.prod(sh)
dtype = str(out.dtype)
v *= numpy.dtype(dtype).itemsize
elif hasattr(out.type, 'get_size'):
v = out.type.get_size(sh)
sum_dense += v
else:
v = "Unknow"
......@@ -746,9 +742,7 @@ class ProfileStats(object):
if any([isinstance(out.type, theano.sparse.SparseType)
for out in node.outputs]):
size = "%10s" % "Sparse"
elif all([isinstance(out.type,
(theano.tensor.TensorType,
theano.sandbox.cuda.CudaNdarrayType))
elif all([hasattr(out.type, 'get_size')
for out in node.outputs]):
size = "%9dB" % node_outputs_size
else:
......
......@@ -417,6 +417,9 @@ class CudaNdarrayType(Type):
def c_compile_args(self):
return []
def get_size(self, shape_info):
return numpy.prod(shape_info, dtype=int) * numpy.dtype(self.dtype).itemsize
theano.compile.ops.expandable_types += (CudaNdarrayType,)
# Register C code for ViewOp on CudaNdarrayType
......
......@@ -1181,6 +1181,9 @@ class TensorType(Type):
"""
return numpy.zeros(shape, dtype=self.dtype)
def get_size(self, shape_info):
return numpy.prod(shape_info, dtype=int) * numpy.dtype(self.dtype).itemsize
theano.compile.ops.expandable_types += (TensorType,)
# Register TensorType C code for ViewOp.
......
......@@ -56,6 +56,22 @@ class RandomStateType(gof.Type):
return False
return True
def get_size(self, shape_info):
# The size is the data, that have constant size.
state = numpy.random.RandomState().get_state()
size = 0
for elem in state:
if isinstance(elem, str):
size += len(elem)
elif isinstance(elem, numpy.ndarray):
size += elem.size * elem.itemsize
elif isinstance(elem, int):
size += numpy.dtype("int").itemsize
elif isinstance(elem, float):
size += numpy.dtype("float").itemsize
else:
raise NotImplementedError()
return size
# Register RandomStateType's C code for ViewOp.
theano.compile.register_view_op_c_code(
RandomStateType,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论