提交 a2cf238f authored 作者: Ramana.S's avatar Ramana.S

ScalarVariable compliant with future divisions

上级 6c02a13d
import unittest
from __future__ import absolute_import, print_function, division
import unittest
import os
import re
......@@ -12,5 +13,5 @@ class FunctionName(unittest.TestCase):
x = tensor.vector('x')
func = theano.function([x], x + 1.)
regex = re.compile(os.path.basename('.*test_function_name.pyc?:13'))
regex = re.compile(os.path.basename('.*test_function_name.pyc?:14'))
assert(regex.match(func.name) is not None)
......@@ -21,7 +21,7 @@ import numpy
from six.moves import xrange
import theano
from theano.compat import PY3, imap, izip
from theano.compat import imap, izip
from theano import gof, printing
from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph)
......@@ -604,12 +604,11 @@ class _scalar_py_operators:
def __mul__(self, other):
return mul(self, other)
if PY3:
def __truediv__(self, other):
return div_proxy(self, other)
else:
def __div__(self, other):
return div_proxy(self, other)
def __truediv__(self, other):
return div_proxy(self, other)
def __div__(self, other):
return div_proxy(self, other)
def __floordiv__(self, other):
return int_div(self, other)
......
......@@ -9,7 +9,8 @@ If you do want to rewrite these tests, bear in mind:
* FunctionGraph and DualLinker are old, use compile.function instead.
"""
from __future__ import division
from __future__ import absolute_import, print_function, division
import unittest
import numpy as np
......
from __future__ import print_function, division
from __future__ import absolute_import, print_function, division
import os
import shutil
import sys
......
from __future__ import division
from __future__ import absolute_import, print_function, division
import numpy
from theano.tensor.elemwise import Elemwise
from theano import scalar
......@@ -54,7 +56,7 @@ class XlogY0(scalar.BinaryScalarOp):
def grad(self, inputs, grads):
x, y = inputs
gz, = grads
return [gz * scalar.log(y), gz * x/y]
return [gz * scalar.log(y), gz * x / y]
def c_code(self, node, name, inputs, outputs, sub):
x, y = inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论