fixed circular imports with xlogx

上级 68d1441b
...@@ -5,6 +5,7 @@ from basic import * ...@@ -5,6 +5,7 @@ from basic import *
import opt import opt
import blas import blas
import xlogx
import raw_random, randomstreams import raw_random, randomstreams
from randomstreams import \ from randomstreams import \
......
import theano import theano
from theano import tensor, scalar
import numpy import numpy
from elemwise import Elemwise
from theano import scalar
class XlogX(scalar.UnaryScalarOp): class XlogX(scalar.UnaryScalarOp):
""" """
Compute X * log(X), with special case 0 log(0) = 0. Compute X * log(X), with special case 0 log(0) = 0.
...@@ -24,7 +26,7 @@ class XlogX(scalar.UnaryScalarOp): ...@@ -24,7 +26,7 @@ class XlogX(scalar.UnaryScalarOp):
: %(x)s * log(%(x)s);""" % locals() : %(x)s * log(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx') scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx')
xlogx = tensor.Elemwise(scalar_xlogx, name='xlogx') xlogx = Elemwise(scalar_xlogx, name='xlogx')
class XlogY0(scalar.BinaryScalarOp): class XlogY0(scalar.BinaryScalarOp):
...@@ -48,5 +50,4 @@ class XlogY0(scalar.BinaryScalarOp): ...@@ -48,5 +50,4 @@ class XlogY0(scalar.BinaryScalarOp):
: %(x)s * log(%(y)s);""" % locals() : %(x)s * log(%(y)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogy0 = XlogY0(scalar.upgrade_to_float, name='scalar_xlogy0') scalar_xlogy0 = XlogY0(scalar.upgrade_to_float, name='scalar_xlogy0')
xlogy0 = tensor.Elemwise(scalar_xlogy0, name='xlogy0') xlogy0 = Elemwise(scalar_xlogy0, name='xlogy0')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论