fixed circular imports with xlogx

上级 68d1441b
......@@ -5,6 +5,7 @@ from basic import *
import opt
import blas
import xlogx
import raw_random, randomstreams
from randomstreams import \
......
import theano
from theano import tensor, scalar
import numpy
from elemwise import Elemwise
from theano import scalar
class XlogX(scalar.UnaryScalarOp):
"""
Compute X * log(X), with special case 0 log(0) = 0.
......@@ -24,7 +26,7 @@ class XlogX(scalar.UnaryScalarOp):
: %(x)s * log(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx')
xlogx = tensor.Elemwise(scalar_xlogx, name='xlogx')
xlogx = Elemwise(scalar_xlogx, name='xlogx')
class XlogY0(scalar.BinaryScalarOp):
......@@ -48,5 +50,4 @@ class XlogY0(scalar.BinaryScalarOp):
: %(x)s * log(%(y)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
scalar_xlogy0 = XlogY0(scalar.upgrade_to_float, name='scalar_xlogy0')
xlogy0 = tensor.Elemwise(scalar_xlogy0, name='xlogy0')
xlogy0 = Elemwise(scalar_xlogy0, name='xlogy0')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论