提交 2480f1d1 authored 作者: james@crane's avatar james@crane

oplist, draft 1

上级 96a7d0f5
import sys
import gof
import tensor
def isOp(thing):
return hasattr(thing, 'perform')
def isOpClass(thing):
return hasattr(thing, 'perform') and not isinstance(thing, gof.Op)
def isOpConstructor(thing, module):
return hasattr(thing, 'perform') and isinstance(thing, gof.Op)\
or thing in getattr(module, '_constructor_list', [])
def print_title(title_string, under_char):
print title_string
print under_char * len(title_string)
def chomp(s):
"""interpret and left-align a docstring"""
......@@ -35,28 +42,53 @@ def chomp(s):
return "".join(r)
for module in [tensor]:
title = 'Ops in module: `%s`' % module.__name__
print title
print '-' * len(title)
import elemwise, scalar, sparse, tensor
print_title("Theano Op List", "~")
print ""
print ".. contents:: "
print ""
for module in [elemwise, scalar, sparse, tensor]:
print_title('module: `%s`' % module.__name__, '=')
print_title('Op Classes', '-')
for symbol_name in dir(module):
symbol = getattr(module, symbol_name)
if isOp(symbol):
if isOpClass(symbol) and symbol.__module__ == module.__name__:
print ""
print "- :api:`%s.%s`" % (module.__name__, symbol_name)
docstring = getattr(symbol, '__doc__', "")
if not docstring:
print " ", 'No documentation'
print " ", '(no doc)'
elif len(docstring) < 50:
print " ", chomp(docstring)
else:
print " ", chomp(docstring[:40]), "..."
# a little trailing whitespace
print ""
print_title('Op Constructors', '-')
for symbol_name in dir(module):
symbol = getattr(module, symbol_name)
if isOpConstructor(symbol, module) \
and symbol.__module__ == module.__name__:
print ""
print "- :api:`%s.%s`" % (module.__name__, symbol_name)
docstring = getattr(symbol, '__doc__', "")
if not docstring:
print " ", 'No documentation'
elif len(docstring) < 50:
print " ", chomp(docstring)
else:
print " ", chomp(docstring[:40]), "..."
# a little trailing whitespace
print ""
......@@ -26,6 +26,14 @@ from elemwise import Elemwise, DimShuffle, CAReduce, Sum
import tensor_random as random
_constructor_list = []
"""List of functions to be listed as op constructors in the oplist (`gen_oplist`, doc/oplist.txt)."""
def constructor(f):
"""Make `f` appear as a constructor in the oplist (`gen_oplist`, doc/oplist.txt)."""
_constructor_list.append(f)
return f
def as_tensor(x, name = None):
"""Return `x`, transformed into a `Tensor`
......@@ -485,6 +493,8 @@ def _elemwise(scalar_op, name):
straight = elemwise.Elemwise(scalar_op, name = name)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = name+"_inplace")
_constructor_list.append(straight)
# don't add the inplace versions, they aren't supposed to be part of the user interface
return straight, inplace
......@@ -520,6 +530,7 @@ class ScalarFromTensor(Op):
scalar_from_tensor = ScalarFromTensor()
@constructor
def cast(t, dtype):
mapping = {'int8': convert_to_int8,
'int16': convert_to_int16,
......@@ -596,6 +607,7 @@ max_and_argmax = MaxAndArgmax()
@constructor
def max(x, axis=None):
"""Return indexes of maximum elements obtained by iterating over given axis
......@@ -606,6 +618,7 @@ def max(x, axis=None):
# but when Argmax.c_impl() is in place, it should be fine.
return max_and_argmax(x,axis)[0]
@constructor
def argmax(x, axis=None):
"""Return maximum elements obtained by iterating over given axis
......@@ -665,9 +678,12 @@ tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh')
fill, fill_inplace = _elemwise(scal.second, 'fill')
@constructor
def ones_like(model):
#return Ones(model.type.ndim)(shape(model))
return fill(model, 1.0)
@constructor
def zeros_like(model):
#return Zeros(model.type.ndim)(shape(model))
return fill(model, 0.0)
......@@ -708,12 +724,14 @@ class Filler(gof.Op):
Zeros = functools.partial(Filler, 0)
Ones = functools.partial(Filler, 1)
@constructor
def zero():
"""
Return a scalar zero, e.g. for initializing sums.
"""
return Zeros(0)([])
@constructor
def one():
return Ones(0)([])
......@@ -721,9 +739,11 @@ def one():
tensor_copy = elemwise.Elemwise(scal.identity)
identity = elemwise.Elemwise(scal.identity, inplace_pattern = {0: [0]})
@constructor
def sum(input, axis = None):
return elemwise.Sum(axis)(input)
@constructor
def mean(input, axis = None):
s = sum(input, axis)
shp = shape(input)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论