提交 3ed76a56 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

tests for OpFromGraph

上级 7a7e27a6
...@@ -9,7 +9,7 @@ from gof import \ ...@@ -9,7 +9,7 @@ from gof import \
Type, Generic, generic, \ Type, Generic, generic, \
object2, utils object2, utils
from compile import function, eval_outputs, fast_compute from compile import function, eval_outputs, fast_compute, OpFromGraph
import tensor import tensor
import tensor_random import tensor_random
......
...@@ -127,7 +127,36 @@ class T_fast_compute(unittest.TestCase): ...@@ -127,7 +127,36 @@ class T_fast_compute(unittest.TestCase):
e = x*x + y*y + z*z e = x*x + y*y + z*z
assert fast_compute(e) == 14.0 assert fast_compute(e) == 14.0
assert compile._fcache[(e, )]() == 14.0 assert compile._fcache[(e, )]() == 14.0
import tensor as T
import numpy as N
class T_OpFromGraph(unittest.TestCase):
def test_straightforward(self):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e], linker='c|py')
f = op(x, y, z) - op(y, z, x)
fn = function([x, y, z], [f])
xv, yv, zv = N.ones((2, 2)), N.ones((2, 2))*3, N.ones((2, 2))*5
assert numpy.all(8.0 == fn(xv, yv, zv))
assert numpy.all(8.0 == fn(xv, yv, zv))
def test_size_changes(self):
x, y, z = T.matrices('xyz')
e = T.dot(x, y)
op = OpFromGraph([x, y], [e], linker='c|py')
f = op(x, op(y, z))
fn = function([x, y, z], [f])
xv, yv, zv = N.ones((2, 3)), N.ones((3, 4))*3, N.ones((4, 5))*5
res = fn(xv, yv, zv)
assert res.shape == (2, 5)
assert numpy.all(180.0 == res)
res = fn(xv, yv, zv)
assert res.shape == (2, 5)
assert numpy.all(180.0 == res)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -188,13 +188,15 @@ class OpFromGraph(gof.Op): ...@@ -188,13 +188,15 @@ class OpFromGraph(gof.Op):
function and the resulting Op's perform will do the same operation as function and the resulting Op's perform will do the same operation as
function(inputs, outputs, **kwargs) function(inputs, outputs, **kwargs)
Take note that the following arguments will be forcefully set to Take note that the following options, if provided, must take the
a particular value: value(s) listed below:
unpack_single = False unpack_single = False
borrow_outputs = False borrow_outputs = False
""" """
def __init__(self, inputs, outputs, **kwargs): def __init__(self, inputs, outputs, **kwargs):
if kwargs.get('borrow_outputs') or kwargs.get('unpack_single'):
raise ValueError('The borrow_outputs and unpack_single options cannot be True')
kwargs['unpack_single'] = False kwargs['unpack_single'] = False
kwargs['borrow_outputs'] = False kwargs['borrow_outputs'] = False
self.fn = function(inputs, outputs, **kwargs) self.fn = function(inputs, outputs, **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论