提交 ae18fb3c authored 作者: Frederic Bastien's avatar Frederic Bastien

manual merge.

......@@ -73,6 +73,8 @@ Software Requirements
- g++, python-dev (optional, to compile generated C code)
- nose, for testing
- `psyco <http://psyco.sourceforge.net/>`__ can make your python code much faster, if you are on a 32-bit x86 architecture. If you use compiled C code, this can be less important.
Downloading Theano
......@@ -89,7 +91,7 @@ Get the source and run the tests like this:
hg clone http://pylearn.org/hg/theano theano
cd theano
nosetest
nosetests
To update your library to the latest on pylearn.org, change directory (`cd`) to this `theano` folder and type
......
......@@ -25,7 +25,7 @@ Our project uses the name to honour the ancient Greek mathematician.
Overview
========
**To get up & running quickly** see README_.
**To get up and running quickly** see README_.
All **documentation** can be reached from the `Theano Project Documentation Overview`_.
......
......@@ -44,7 +44,9 @@ from compile import \
predefined_modes, predefined_linkers, predefined_optimizers, \
FunctionMaker, function, OpFromGraph, \
Component, External, Member, KitComponent, Method, \
Composite, ComponentList, Module, FancyModule
Composite, ComponentList, ComponentDict, Module
FancyModule = Module
from printing import \
pprint, pp
......
......@@ -10,34 +10,73 @@ import function_module as F
def join(*args):
"""
Creates a string representation for the given names:
join('a', 'b', 'c') => 'a.b.c'
"""
return ".".join(arg for arg in args if arg)
def split(sym, n=-1):
"""
Gets the names from their joined representation
split('a.b.c') => 'a', 'b', 'c'
Returns the n first names, if n==-1 returns all of them.
"""
return sym.split('.', n)
def canonicalize(name):
"""
Splits the name and converts each name to the
right type (e.g. "2" -> 2)
"""
if isinstance(name, str):
name = split(name)
def convert(x):
try:
return int(x)
except ValueError:
except (ValueError, TypeError):
return x
return map(convert, name)
class AllocationError(Exception):
"""
Exception raised when a Result has no associated storage.
"""
pass
class BindError(Exception):
"""
Exception raised when a Component is already bound and we try to
bound it again.
"""
pass
class Component(object):
"""
Base class for the various kinds of components which are not
structural but may be meaningfully used in structures (Member,
Method, etc.)
"""
def __init__(self):
self.__dict__['_name'] = ''
self.__dict__['parent'] = None
def bind(self, parent, name, dup_ok=True):
"""
Marks this component as belonging to the parent (the parent is
typically a Composite instance). The component can be accessed
through the parent with the specified name. If dup_ok is True
and that this Component is already bound, a duplicate of the
component will be made using the dup() method and the
duplicate will be bound instead of this Component. If dup_ok
is False and this Component is already bound, a BindError wil
be raised.
bind() returns the Component instance which has been bound to
the parent. For an unbound instance, this will usually be
self.
"""
if self.bound():
if dup_ok:
try:
......@@ -54,21 +93,48 @@ class Component(object):
return self
def bound(self):
"""
Returns True if this Component instance is bound to a
Composite.
"""
return self.parent is not None
def allocate(self, memo):
"""
Populates the memo dictionary with Result -> Container
pairings.
"""
raise NotImplementedError
def build(self, mode, memo):
"""
Makes an instance of this Component using the mode provided
and taking the containers in the memo dictionary.
A Component which builds nothing may return None.
"""
raise NotImplementedError
def make_no_init(self, mode='FAST_COMPILE'):
"""
Allocates the necessary containers using allocate() and uses
build() with the provided mode to make an instance which will
be returned. The initialize() method of the instance will not
be called.
"""
memo = {}
self.allocate(memo)
rval = self.build(mode, memo)
return rval
def make(self, *args, **kwargs):
"""
Allocates the necessary containers using allocate() and uses
build() to make an instance which will be returned. The
initialize() method of the instance will be called with the
arguments and the keyword arguments. If 'mode' is in the
keyword arguments it will be passed to build().
"""
mode = kwargs.pop('mode', 'FAST_COMPILE')
rval = self.make_no_init(mode)
if hasattr(rval, 'initialize'):
......@@ -82,20 +148,34 @@ class Component(object):
return self.__class__.__name__
def pretty(self, **kwargs):
"""
Returns a pretty representation of this Component, suitable
for reading.
"""
raise NotImplementedError
def __get_name__(self):
"""
Getter for self.name
"""
return self._name
def __set_name__(self, name):
"""
Setter for self.name
"""
self._name = name
name = property(lambda self: self.__get_name__(),
lambda self, value: self.__set_name__(value))
lambda self, value: self.__set_name__(value),
"Contains the name of this Component")
class _RComponent(Component):
"""
Base class for a Component wrapping a Result. For internal use.
"""
def __init__(self, r):
super(_RComponent, self).__init__()
......@@ -119,12 +199,19 @@ class _RComponent(Component):
class External(_RComponent):
"""
External represents a Result which comes from somewhere else
(another module) or is a temporary calculation.
"""
def allocate(self, memo):
# nothing to allocate
return None
def build(self, mode, memo):
"""
Builds nothing.
"""
return None
def pretty(self, **kwargs):
......@@ -136,8 +223,19 @@ class External(_RComponent):
class Member(_RComponent):
"""
Member represents a Result which is a state of a Composite. That
Result will be accessible from a built Composite and it is
possible to do updates on Members.
Member builds a gof.Container.
"""
def allocate(self, memo):
"""
If the memo does not have a Container associated to this
Member's Result, instantiates one and sets it in the memo.
"""
r = self.r
if memo and r in memo:
return memo[r]
......@@ -146,6 +244,9 @@ class Member(_RComponent):
return rval
def build(self, mode, memo):
"""
Returns the Container associated to this Member's Result.
"""
return memo[self.r]
......@@ -153,6 +254,20 @@ class Member(_RComponent):
class Method(Component):
def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates):
"""
Method is a declaration of a function. It contains inputs,
outputs, updates and kits. If the Method is part of a
Composite which holds references to Members, the Method may
use them without declaring them in the inputs, outputs or
updates list.
inputs, outputs or updates may be strings. In that case, they
will be resolved in the Composite which is the parent of this
Method.
Method builds a Function (same structure as a call to
theano.function)
"""
super(Method, self).__init__()
self.inputs = inputs
self.outputs = outputs
......@@ -165,6 +280,9 @@ class Method(Component):
return rval
def resolve(self, name):
"""
Resolves the name of an input or output in the parent.
"""
if not self.bound():
raise ValueError('Trying to resolve a name on an unbound Method.')
result = self.parent.resolve(name)
......@@ -175,16 +293,23 @@ class Method(Component):
def resolve_result(self, x):
if isinstance(x, gof.Result):
return x
elif isinstance(x, _RComponent):
return x.r
else:
return self.resolve(x).r
def resolve_all(self):
if not isinstance(self.inputs, (list, tuple)):
"""
Resolves all inputs, outputs and updates that were given as
strings so that the fields contain the corresponding Result
instances instead.
"""
if isinstance(self.inputs, (gof.Result, str)):
inputs = [self.inputs]
else:
inputs = self.inputs
inputs = list(self.inputs)
self.inputs = [self.resolve_result(input) for input in inputs]
if isinstance(self.outputs, (list, tuple)):
if isinstance(self.outputs, (list, tuple, ComponentList)):
self.outputs = [self.resolve_result(output) for output in self.outputs]
else:
self.outputs = self.resolve_result(self.outputs)
......@@ -195,11 +320,22 @@ class Method(Component):
self.updates[k] = v
def allocate(self, memo):
"""
Method allocates nothing.
"""
return None
def build(self, mode, memo, allocate_all = False):
self.resolve_all()
"""
Produces a function. If allocate_all is True, storage will be
allocated for all needed Results, even if there is no
associated storage for them in the memo. If allocate_all is
False, storage will only be allocated for Results that are
reachable from the inputs list.
"""
self.resolve_all() # resolve all so we don't have to mess with strings
def get_storage(r, require = False):
# If require is True, we can only get storage from the memo.
try:
return memo[r]
except KeyError:
......@@ -209,11 +345,13 @@ class Method(Component):
' enclosing module or of one of its submodules.' % (r, self))
else:
return gof.Container(r, storage = [None])
# Wrap the inputs in In instances.
inputs = self.inputs
inputs = [io.In(result = input,
value = get_storage(input),
mutable = False)
for input in inputs]
# Add the members to update to the inputs.
inputs += [io.In(result = k,
update = v,
value = get_storage(k, not allocate_all),
......@@ -222,13 +360,20 @@ class Method(Component):
for k, v in self.updates.iteritems()]
outputs = self.outputs
_inputs = [x.result for x in inputs]
# Grab the results that are not accessible from either the inputs or the updates.
for input in gof.graph.inputs((list(outputs) if isinstance(outputs, (list, tuple)) else [outputs])
+ [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value):
# Add this input to the inputs; we require that storage already exists for them,
# but otherwise they are immutable.
inputs += [io.In(result = input,
value = get_storage(input, not allocate_all),
mutable = False)]
# Add the kits to the input. The kit should be associated in
# memo to a list of Containers. theano.function handles that
# case by picking only the needed Containers from the list, so
# here we can just delegate to theano.function.
inputs += [(kit, get_storage(kit, not allocate_all)) for kit in self.kits]
return F.function(inputs, outputs, mode)
......@@ -238,8 +383,10 @@ class Method(Component):
rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs))
else:
rval = ''
mode = kwargs.pop('mode', None)
inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates
# If mode is in kwargs, prints the optimized version of the method
mode = kwargs.pop('mode', None)
if mode:
f = self.build(mode, {}, True)
einputs, eoutputs = f.maker.env.inputs, f.maker.env.outputs
......@@ -282,13 +429,21 @@ class Method(Component):
class CompositeInstance(object):
"""
Generic type which various Composite subclasses are intended to
build.
"""
def __init__(self, component, __items__):
# The Component that built this CompositeInstance
self.__dict__['component'] = component
# Some data structure indexable using []
self.__dict__['__items__'] = __items__
def __getitem__(self, item):
x = self.__items__[item]
# For practical reasons, if the item is a Container, we
# return its contents.
if isinstance(x, gof.Container):
return x.value
return x
......@@ -296,14 +451,20 @@ class CompositeInstance(object):
def __setitem__(self, item, value):
x = self.__items__[item]
if isinstance(x, gof.Container):
# If the item is a Container, we set its value
x.value = value
elif hasattr(x, 'initialize'):
# If the item has an initialize() method, we use
# it with the value as argument
x.initialize(value)
else:
##self.__items__[item] = value
raise KeyError('Cannot set item %s' % item)
class Composite(Component):
"""
Composite represents a structure that contains Components.
"""
def resolve(self, name):
raise NotImplementedError
......@@ -321,6 +482,12 @@ class Composite(Component):
raise NotImplementedError
def flat_components(self, include_self = False):
"""
Generator that yields each component in a flattened hierarchy
of composites and components. If include_self is True, the
list will include the Composite instances, else it will only
yield the list of leaves.
"""
if include_self:
yield self
for component in self.components():
......@@ -331,6 +498,15 @@ class Composite(Component):
yield component
def flat_components_map(self, include_self = False, path = []):
"""
Generator that yields (path, component) pairs in a flattened
hierarchy of composites and components, where path is a
sequence of keys such that
component is self[path[0]][path[1]]...
If include_self is True, the list will include the Composite
instances, else it will only yield the list of leaves.
"""
if include_self:
yield path, self
for name, component in self.components_map():
......@@ -342,22 +518,33 @@ class Composite(Component):
yield path2, component
def allocate(self, memo):
"""
Does allocation for each component in the composite.
"""
for member in self.components():
member.allocate(memo)
def get(self, item):
"""
Get the Component associated to the key.
"""
raise NotImplementedError
def set(self, item, value):
"""
Set the Component associated to the key.
"""
raise NotImplementedError
def __getitem__(self, item):
# Uses get() internally
x = self.get(item)
if isinstance(x, (External, Member)):
return x.r
return x
def __setitem__(self, item, value):
# Uses set() internally
self.set(item, value)
def __iter__(self):
......@@ -378,6 +565,10 @@ class ComponentListInstance(CompositeInstance):
self[i] = initv
class ComponentList(Composite):
"""
ComponentList represents a sequence of Component. It builds a
ComponentListInstance.
"""
def __init__(self, *_components):
super(ComponentList, self).__init__()
......@@ -388,6 +579,9 @@ class ComponentList(Composite):
self.append(c)
def resolve(self, name):
# resolves # to the #th number in the list
# resolves name string to parent.resolve(name)
# TODO: eliminate canonicalize
name = canonicalize(name)
try:
item = self.get(name[0])
......@@ -397,6 +591,7 @@ class ComponentList(Composite):
raise TypeError('Cannot resolve a non-integer name on an unbound ComponentList.')
return self.parent.resolve(name)
if len(name) > 1:
# TODO: eliminate
return item.resolve(name[1:])
return item
......@@ -447,6 +642,9 @@ class ComponentList(Composite):
def __str__(self):
return str(self._components)
def __len__(self):
return len(self._components)
def pretty(self, **kwargs):
cr = '\n ' #if header else '\n'
strings = []
......@@ -466,13 +664,18 @@ class ComponentList(Composite):
return self.__class__(*[c.dup() for c in self._components])
class ModuleInstance(CompositeInstance):
class ComponentDictInstance(CompositeInstance):
"""
ComponentDictInstance is meant to be instantiated by ComponentDict.
"""
def __setitem__(self, item, value):
if item not in self.__items__:
# Set it if it's not there
# TODO: is this needed here? move to ModuleInstance?
self.__items__[item] = value
return
super(ModuleInstance, self).__setitem__(item, value)
super(ComponentDictInstance, self).__setitem__(item, value)
def __str__(self):
strings = []
......@@ -485,11 +688,11 @@ class ModuleInstance(CompositeInstance):
return '{%s}' % '\n'.join(strings).replace('\n', '\n ')
class Module(Composite):
InstanceType = ModuleInstance
class ComponentDict(Composite):
InstanceType = ComponentDictInstance # Type used by build() to make the instance
def __init__(self, components = {}, **kwcomponents):
super(Module, self).__init__()
super(ComponentDict, self).__init__()
components = dict(components, **kwcomponents)
self.__dict__['_components'] = components
......@@ -519,7 +722,7 @@ class Module(Composite):
def set(self, item, value):
if not isinstance(value, Component):
raise TypeError('Module may only contain Components.', value, type(value))
raise TypeError('ComponentDict may only contain Components.', value, type(value))
value = value.bind(self, item)
self._components[item] = value
......@@ -527,7 +730,7 @@ class Module(Composite):
cr = '\n ' #if header else '\n'
strings = []
# if header:
# rval += "Module:"
# rval += "ComponentDict:"
for name, component in self.components_map():
if name.startswith('_'):
continue
......@@ -536,10 +739,10 @@ class Module(Composite):
return '\n'.join(strings)
def __str__(self):
return "Module(%s)" % ', '.join(x for x in sorted(map(str, self._components)) if x[0] != '_')
return "ComponentDict(%s)" % ', '.join(x for x in sorted(map(str, self._components)) if x[0] != '_')
def __set_name__(self, name):
super(Module, self).__set_name__(name)
super(ComponentDict, self).__set_name__(name)
for mname, member in self._components.iteritems():
member.name = '%s.%s' % (name, mname)
......@@ -553,6 +756,10 @@ def register_wrapper(condition, wrapper):
__autowrappers.append((condition, wrapper))
def wrap(x):
"""
Wraps x in a Component. Wrappers can be registered using
register_wrapper to allow wrapping more types.
"""
if isinstance(x, Component):
return x
for condition, wrapper in __autowrappers:
......@@ -560,12 +767,15 @@ def wrap(x):
return wrapper(x)
return x
# Result -> External
register_wrapper(lambda x: isinstance(x, gof.Result),
lambda x: External(x))
# [Component1, Component2, ...] -> ComponentList(Component1, Component2, ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) and all(isinstance(r, Component) for r in x),
lambda x: ComponentList(*x))
# [Result1, Result2, ...] -> ComponentList(Member(Result1), Member(Result2), ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) \
and all(isinstance(r, gof.Result) and not r.owner for r in x),
lambda x: ComponentList(*map(Member, x)))
......@@ -586,7 +796,14 @@ class Curry:
self.meth = getattr(self.obj, self.name)
class FancyModuleInstance(ModuleInstance):
class ModuleInstance(ComponentDictInstance):
"""
ModuleInstance is meant to be instantiated by Module. This differs
from ComponentDictInstance on a key point, which is that getattr
does a similar thing to getitem.
ModuleInstance is compatible for use as ComponentDict.InstanceType.
"""
def __getattr__(self, attr):
if attr == '__items__' and '__items__' not in self.__dict__:
......@@ -602,10 +819,14 @@ class FancyModuleInstance(ModuleInstance):
except KeyError:
self.__dict__[attr] = value
class FancyModule(Module):
InstanceType = FancyModuleInstance
class Module(ComponentDict):
InstanceType = ModuleInstance # By default, we use build ModuleInstance
def __wrapper__(self, x):
"""
This function is called whenever x is set as an attribute of
the Module.
"""
return wrap(x)
def __getattr__(self, attr):
......@@ -616,6 +837,8 @@ class FancyModule(Module):
except KeyError:
raise AttributeError('%s has no %s attribute.' % (self.__class__, attr))
if isinstance(rval, (External, Member)):
# Special treatment for External and Member, so that
# the user may use them to build graphs more easily.
return rval.r
return rval
......@@ -637,25 +860,40 @@ class FancyModule(Module):
self.__dict__[attr] = value
def build(self, mode, memo):
inst = super(FancyModule, self).build(mode, memo)
inst = super(Module, self).build(mode, memo)
for method in dir(self):
# Any method with a name like '_instance_XXX' is added to
# the object built under the name obj.XXX
if method.startswith('_instance_'):
setattr(inst, method[10:], Curry(self, method, inst))
return inst
def _instance_initialize(self, inst, init = {}, **kwinit):
"""
Default initialization method.
"""
for name, value in chain(init.iteritems(), kwinit.iteritems()):
inst[name] = value
FancyModule = Module
FancyModuleInstance = ModuleInstance
class KitComponent(Component):
"""
Represents a SymbolicInputKit (see io.py).
"""
def __init__(self, kit):
super(KitComponent, self).__init__()
self.kit = kit
def allocate(self, memo):
"""
Allocates a Container for each input in the kit. Sets a key in
the memo that maps the SymbolicInputKit to the list of
Containers.
"""
kit = self.kit
if kit in memo:
return memo[kit]
......
......@@ -246,6 +246,7 @@ class LocalLinker(Linker):
class PerformLinker(LocalLinker):
"""WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{Env} in the order given by L{Env.toposort}.
"""
......@@ -254,6 +255,13 @@ class PerformLinker(LocalLinker):
self.env = None
def accept(self, env, no_recycling = []):
"""
:param env: a PerformLinker can have accepted one Env instance at a time.
:param no_recycling: WRITEME
:returns: self (WHY? Who calls this function?)
"""
if self.env is not None and self.env is not env:
return type(self)().accept(env, no_recycling)
#raise Exception("Cannot accept from a Linker that is already tied to another Env.")
......@@ -262,6 +270,14 @@ class PerformLinker(LocalLinker):
return self
def make_all(self, profiler = None, input_storage = None, output_storage = None):
"""
:param profiler: WRITEME
:param input_storage: WRITEME
:param output_storage: WRITEME
:returns: WRITEME (or see: SOMETHING)
"""
env = self.env
order = env.toposort()
no_recycling = self.no_recycling
......
"""Pretty-printing graphs, and the 'Print' Op.
"""
import gof
from copy import copy
import sys
from gof import Op, Apply
class Print(Op):
"""This identity-like Op has the side effect of printing a message followed by its inputs
when it runs.
"""
def __init__(self,message=""):
self.message=message
self.view_map={0:[0]}
def make_node(self,xin):
xout = xin.type.make_result()
return Apply(op = self, inputs = [xin], outputs=[xout])
def perform(self,node,inputs,output_storage):
xin, = inputs
xout, = output_storage
xout[0] = xin
print self.message,xin
def grad(self,input,output_gradients):
return output_gradients
class PrinterState(gof.utils.scratchpad):
......@@ -232,3 +255,4 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint
......@@ -21,7 +21,7 @@ from .. import scalar as scal
from ..gof.python25 import partial
from .. import compile, printing
from ..printing import pprint
from ..printing import pprint, Print
### set up the external interface
......@@ -456,10 +456,11 @@ class _tensor_py_operators:
def __abs__(self): return abs_(self)
def __neg__(self): return neg(self)
#CASTS
def __int__(self): return AsInt(self).out
def __float__(self): return AsInt(self).out
def __complex__(self): return AsComplex(self).out
#CASTS
#### REMOVED THESE BECAUSE PYTHON appears to require __int__ to return an int. -JB 20081112
#def __int__(self): return convert_to_int32(self)
#def __float__(self): return convert_to_float64(self)
#def __complex__(self): return convert_to_complex128(self)
#COMPARISONS
def __lt__(self,other): return lt(self, other)
......@@ -712,7 +713,7 @@ class Shape(Op):
x = as_tensor(x)
return Apply(self, [x], [lvector()])
def perform(self, node, (x, ), (out, )):
out[0] = numpy.asarray(x.shape)
out[0] = numpy.asarray(x.shape, dtype = 'int64')
def grad(self, (x,), (gz,)):
return [None]
@_redefine_asRoutine(Shape())
......@@ -1012,6 +1013,10 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def mean(input, axis = None):
"""WRITEME"""
if str(input.dtype).startswith('int'):
# we need to cast eventually anyway, and this helps
# to prevents overflow
input = convert_to_float64(input)
s = sum(input, axis)
shp = shape(input)
if axis is None:
......@@ -1589,7 +1594,7 @@ def concatenate(tensor_list, axis=0):
if not isinstance(tensor_list, (tuple, list)):
raise TypeError("The 'tensors' argument must be either a tuple "
"or a list, make sure you did not forget () or [] around "
"arguments of concatenate.", tensors)
"arguments of concatenate.", tensor_list)
return join(axis, *tensor_list)
def get_vector_length(v):
......
......@@ -55,8 +55,9 @@ class RandomFunction(gof.Op):
r = copy(r)
rout[0] = r
rval = self.fn(r, *(args + [shape]))
if not isinstance(rval, numpy.ndarray):
out[0] = numpy.asarray(rval, dtype = node.outputs[0].type.dtype)
if not isinstance(rval, numpy.ndarray) \
or str(rval.dtype) != node.outputs[1].type.dtype:
out[0] = numpy.asarray(rval, dtype = node.outputs[1].type.dtype)
else:
out[0] = rval
......@@ -237,7 +238,7 @@ class RandomKit(SymbolicInputKit):
rk = RandomKit('rk', 0xBAD5EED)
class RModule(compile.FancyModule):
class RModule(compile.Module):
def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents)
......
from xlogx import xlogx
import unittest
from theano import compile
from theano import gradient
from theano.tensor import as_tensor
import theano._test_tensor as TT
import random
import numpy.random
class T_XlogX(unittest.TestCase):
def test0(self):
x = as_tensor([1, 0])
y = xlogx(x)
y = compile.eval_outputs([y])
self.failUnless(numpy.all(y == numpy.asarray([0, 0.])))
def test1(self):
class Dummy(object):
def make_node(self, a):
return [xlogx(a)[:,2]]
TT.verify_grad(self, Dummy(), [numpy.random.rand(3,4)])
if __name__ == '__main__':
unittest.main()
import theano
from theano import tensor, scalar
import numpy
class XlogX(scalar.UnaryScalarOp):
"""
Compute X * log(X), with special case 0 log(0) = 0.
"""
@staticmethod
def st_impl(x):
if x == 0.0:
return 0.0
return x * numpy.log(x)
def impl(self, x):
return XlogX.st_impl(x)
def grad(self, (x,), (gz,)):
return [gz * (1 + scalar.log(x))]
def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in [scalar.float32, scalar.float64]:
return """%(z)s =
%(x)s == 0.0
? 0.0
: %(x)s * log(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx')
xlogx = tensor.Elemwise(scalar_xlogx, name='xlogx')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论