提交 0ae95f31 authored 作者: Marc-Antoine Rondeau's avatar Marc-Antoine Rondeau

Added support for __divmod__ and __rdivmod__, with a basic unit test.

上级 96a2c0b6
...@@ -2905,6 +2905,9 @@ def div_proxy(x, y): ...@@ -2905,6 +2905,9 @@ def div_proxy(x, y):
as_tensor_variable(y).dtype in discrete_dtypes)) as_tensor_variable(y).dtype in discrete_dtypes))
return f(x, y) return f(x, y)
def divmod(x, y):
"""elementvise divmod, using floor_div and mod_check"""
return floor_div(x, y), mod_check(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1) @_scal_elemwise_with_nfunc('add', 2, 1)
def add(a, *other_terms): def add(a, *other_terms):
......
...@@ -6113,7 +6113,7 @@ def test_len(): ...@@ -6113,7 +6113,7 @@ def test_len():
def test_mod(): def test_mod():
""" """
We add this test as not all language and C implementation give the same We add this test as not all language and C implementation give the same
signe to the result. This check that the c_code of `Mod` is implemented sign to the result. This check that the c_code of `Mod` is implemented
as Python. That is what we want. as Python. That is what we want.
""" """
x, y = fscalars('xy') x, y = fscalars('xy')
...@@ -6126,6 +6126,23 @@ def test_mod(): ...@@ -6126,6 +6126,23 @@ def test_mod():
assert fn(a, b) == a % b, (a,) assert fn(a, b) == a % b, (a,)
def test_divmod():
"""
Confirm that divmod is equivalent to the python version.
"""
x, y = fscalars('xy')
d, r = divmod(x, y)
fn = gof.DualLinker().accept(
gof.FunctionGraph([x, y], [d, r])).make_function()
for a, b in ((0, 1), (1, 1), (0, -1), (1, -1), (-1, -1),
(1, 2), (-1, 2), (1, -2), (-1, -2),
(5, 3), (-5, 3), (5, -3), (-5, -3)
):
d_v, r_v = fn(a, b)
d_vp, r_vp = divmod(a, b)
assert d_v == d_vp and r_v == r_vp, (a,)
def test_mod_compile(): def test_mod_compile():
""" """
This test generate an Elemwise of Composite as: This test generate an Elemwise of Composite as:
......
...@@ -190,6 +190,9 @@ class _tensor_py_operators: ...@@ -190,6 +190,9 @@ class _tensor_py_operators:
except (NotImplementedError, AsTensorError): except (NotImplementedError, AsTensorError):
return NotImplemented return NotImplemented
def __divmod__(self, other):
return theano.tensor.basic.divmod(self, other)
def __truediv__(self, other): def __truediv__(self, other):
return theano.tensor.basic.true_div(self, other) return theano.tensor.basic.true_div(self, other)
...@@ -235,6 +238,10 @@ class _tensor_py_operators: ...@@ -235,6 +238,10 @@ class _tensor_py_operators:
def __rmod__(self, other): def __rmod__(self, other):
return theano.tensor.basic.mod(other, self) return theano.tensor.basic.mod(other, self)
def __rdivmod__(self, other):
return theano.tensor.basic.divmod(other, self)
def __rpow__(self, other): def __rpow__(self, other):
return theano.tensor.basic.pow(other, self) return theano.tensor.basic.pow(other, self)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论