Renamed ResultBase to Result

上级 0adda1ff
...@@ -3,11 +3,11 @@ import gof, gof.modes, gof.opt ...@@ -3,11 +3,11 @@ import gof, gof.modes, gof.opt
from compile import * from compile import *
class Double(gof.result.ResultBase): class Double(gof.result.Result):
def __init__(self, data, name = "oignon"): def __init__(self, data, name = "oignon"):
assert isinstance(data, float) assert isinstance(data, float)
gof.result.ResultBase.__init__(self, role = None, name = name) gof.result.Result.__init__(self, role = None, name = name)
self.data = data self.data = data
def __str__(self): def __str__(self):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import time import time
import unittest import unittest
from gof import ResultBase, Op, Env, modes from gof import Result, Op, Env, modes
import gof import gof
from scalar import * from scalar import *
......
...@@ -14,8 +14,8 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -14,8 +14,8 @@ class _test_grad_sources_inputs(unittest.TestCase):
"""Test that it is not ok to return None from op.grad()""" """Test that it is not ok to return None from op.grad()"""
class retNone(gof.op.Op): class retNone(gof.op.Op):
def __init__(self, arg): def __init__(self, arg):
self.inputs = [gof.result.ResultBase()] self.inputs = [gof.result.Result()]
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.Result()]
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
pass pass
a = retNone(5) a = retNone(5)
...@@ -30,10 +30,10 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -30,10 +30,10 @@ class _test_grad_sources_inputs(unittest.TestCase):
class retNone(gof.op.Op): class retNone(gof.op.Op):
def __init__(self, arg): def __init__(self, arg):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.Result()]
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return [None] return [None]
i = gof.result.ResultBase() i = gof.result.Result()
a = retNone([i]) a = retNone([i])
g = grad_sources_inputs([(a.out, 1)], None) g = grad_sources_inputs([(a.out, 1)], None)
self.failUnless(not i in g) self.failUnless(not i in g)
...@@ -43,12 +43,12 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -43,12 +43,12 @@ class _test_grad_sources_inputs(unittest.TestCase):
class retNone(gof.op.Op): class retNone(gof.op.Op):
def __init__(self, arg): def __init__(self, arg):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.Result()]
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
return [None] return [None]
i = gof.result.ResultBase() i = gof.result.Result()
j = gof.result.ResultBase() j = gof.result.Result()
a1 = retNone([i]) a1 = retNone([i])
g = grad_sources_inputs([(a1.out, 1)], None) g = grad_sources_inputs([(a1.out, 1)], None)
a2 = retNone([i,j]) a2 = retNone([i,j])
...@@ -65,22 +65,22 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -65,22 +65,22 @@ class _test_grad_sources_inputs(unittest.TestCase):
class retNone(gof.op.Op): class retNone(gof.op.Op):
def __init__(self, arg, tst): def __init__(self, arg, tst):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.Result()]
self.tst = tst self.tst = tst
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
self.tst.fail() self.tst.fail()
i = gof.result.ResultBase() i = gof.result.Result()
a1 = retNone([i],self) a1 = retNone([i],self)
g = grad_sources_inputs([(a1.out, None)], None) g = grad_sources_inputs([(a1.out, None)], None)
def test_1in_1out(self): def test_1in_1out(self):
"""Test grad is called correctly for a 1-to-1 op""" """Test grad is called correctly for a 1-to-1 op"""
gval = gof.result.ResultBase() gval = gof.result.Result()
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self): def __init__(self):
self.inputs = [gof.result.ResultBase()] self.inputs = [gof.result.Result()]
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.Result()]
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gval, return gval,
a1 = O() a1 = O()
...@@ -89,11 +89,11 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -89,11 +89,11 @@ class _test_grad_sources_inputs(unittest.TestCase):
def test_1in_Nout(self): def test_1in_Nout(self):
"""Test grad is called correctly for a 1-to-many op""" """Test grad is called correctly for a 1-to-many op"""
gval = gof.result.ResultBase() gval = gof.result.Result()
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self): def __init__(self):
self.inputs = [gof.result.ResultBase()] self.inputs = [gof.result.Result()]
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.outputs = [gof.result.Result(),gof.result.Result()]
def grad(self, (x, ), (gz1, gz2)): def grad(self, (x, ), (gz1, gz2)):
return gval, return gval,
a1 = O() a1 = O()
...@@ -101,12 +101,12 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -101,12 +101,12 @@ class _test_grad_sources_inputs(unittest.TestCase):
self.failUnless(g[a1.inputs[0]] is gval) self.failUnless(g[a1.inputs[0]] is gval)
def test_Nin_1out(self): def test_Nin_1out(self):
"""Test grad is called correctly for a many-to-1 op""" """Test grad is called correctly for a many-to-1 op"""
gval0 = gof.result.ResultBase() gval0 = gof.result.Result()
gval1 = gof.result.ResultBase() gval1 = gof.result.Result()
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self): def __init__(self):
self.inputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.inputs = [gof.result.Result(),gof.result.Result()]
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.Result()]
def grad(self, (x0,x1), (gz, )): def grad(self, (x0,x1), (gz, )):
return (gval0, gval1) return (gval0, gval1)
a1 = O() a1 = O()
...@@ -115,12 +115,12 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -115,12 +115,12 @@ class _test_grad_sources_inputs(unittest.TestCase):
self.failUnless(g[a1.inputs[1]] is gval1) self.failUnless(g[a1.inputs[1]] is gval1)
def test_Nin_Nout(self): def test_Nin_Nout(self):
"""Test grad is called correctly for a many-to-many op""" """Test grad is called correctly for a many-to-many op"""
gval0 = gof.result.ResultBase() gval0 = gof.result.Result()
gval1 = gof.result.ResultBase() gval1 = gof.result.Result()
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self): def __init__(self):
self.inputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.inputs = [gof.result.Result(),gof.result.Result()]
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.outputs = [gof.result.Result(),gof.result.Result()]
def grad(self, (x0,x1), (gz0,gz1)): def grad(self, (x0,x1), (gz0,gz1)):
return gval0, gval1 return gval0, gval1
a1 = O() a1 = O()
...@@ -132,11 +132,11 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -132,11 +132,11 @@ class _test_grad_sources_inputs(unittest.TestCase):
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self, arg, tst): def __init__(self, arg, tst):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.outputs = [gof.result.Result(),gof.result.Result()]
self.tst = tst self.tst = tst
def grad(self, inputs, g_out): def grad(self, inputs, g_out):
return [1] return [1]
i = gof.result.ResultBase() i = gof.result.Result()
a1 = O([i],self) a1 = O([i],self)
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[i] is 1) self.failUnless(g[i] is 1)
...@@ -146,7 +146,7 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -146,7 +146,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self, arg, tst, grad_ok): def __init__(self, arg, tst, grad_ok):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.outputs = [gof.result.Result(),gof.result.Result()]
self.tst = tst self.tst = tst
self.grad_ok = grad_ok self.grad_ok = grad_ok
def grad(self, inputs, g_out): def grad(self, inputs, g_out):
...@@ -154,9 +154,9 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -154,9 +154,9 @@ class _test_grad_sources_inputs(unittest.TestCase):
self.tst.fail() self.tst.fail()
else: else:
return [1, None] return [1, None]
i = gof.result.ResultBase() i = gof.result.Result()
j = gof.result.ResultBase() j = gof.result.Result()
k = gof.result.ResultBase() k = gof.result.Result()
a1 = O([i,j],self,True) a1 = O([i,j],self,True)
a2 = O([a1.outputs[1], k], self, True) a2 = O([a1.outputs[1], k], self, True)
g = grad_sources_inputs([(a2.outputs[0], 1)], None) g = grad_sources_inputs([(a2.outputs[0], 1)], None)
...@@ -172,7 +172,7 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -172,7 +172,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self, arg, tst, grad_ok): def __init__(self, arg, tst, grad_ok):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.outputs = [gof.result.Result(),gof.result.Result()]
self.tst = tst self.tst = tst
self.grad_ok = grad_ok self.grad_ok = grad_ok
def grad(self, inputs, (g0,g1)): def grad(self, inputs, (g0,g1)):
...@@ -183,9 +183,9 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -183,9 +183,9 @@ class _test_grad_sources_inputs(unittest.TestCase):
return [g0, g0+g1] return [g0, g0+g1]
else: else:
return [g0, g0] return [g0, g0]
i = gof.result.ResultBase() i = gof.result.Result()
j = gof.result.ResultBase() j = gof.result.Result()
k = gof.result.ResultBase() k = gof.result.Result()
a1 = O([i,j],self,True) a1 = O([i,j],self,True)
a2 = O([k,a1.outputs[1]], self, True) a2 = O([k,a1.outputs[1]], self, True)
g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4), g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
...@@ -202,7 +202,7 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -202,7 +202,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self, arg, tst, grad_ok): def __init__(self, arg, tst, grad_ok):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.outputs = [gof.result.Result(),gof.result.Result()]
self.tst = tst self.tst = tst
self.grad_ok = grad_ok self.grad_ok = grad_ok
def grad(self, inputs, (g0,g1)): def grad(self, inputs, (g0,g1)):
...@@ -213,9 +213,9 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -213,9 +213,9 @@ class _test_grad_sources_inputs(unittest.TestCase):
return [g0, g0+g1] return [g0, g0+g1]
else: else:
return [g0, g0] return [g0, g0]
i = gof.result.ResultBase() i = gof.result.Result()
j = gof.result.ResultBase() j = gof.result.Result()
k = gof.result.ResultBase() k = gof.result.Result()
a1 = O([i,j],self,True) a1 = O([i,j],self,True)
a2 = O([k,a1.outputs[1]], self, True) a2 = O([k,a1.outputs[1]], self, True)
g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4), g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
...@@ -231,10 +231,10 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -231,10 +231,10 @@ class _test_grad_sources_inputs(unittest.TestCase):
class _test_grad(unittest.TestCase): class _test_grad(unittest.TestCase):
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self): def __init__(self):
self.inputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.inputs = [gof.result.Result(),gof.result.Result()]
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.outputs = [gof.result.Result(),gof.result.Result()]
self.gval0 = gof.result.ResultBase() self.gval0 = gof.result.Result()
self.gval1 = gof.result.ResultBase() self.gval1 = gof.result.Result()
def grad(self, (x0,x1), (gz0,gz1)): def grad(self, (x0,x1), (gz0,gz1)):
return self.gval0, self.gval1 return self.gval0, self.gval1
......
import unittest import unittest
from gof import ResultBase, Op, Env, modes from gof import Result, Op, Env, modes
import gof import gof
from scalar import * from scalar import *
......
...@@ -104,7 +104,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -104,7 +104,7 @@ class _testCase_dot(unittest.TestCase):
m = mtype(a) m = mtype(a)
ab = m.dot(b) ab = m.dot(b)
try: try:
z = dot(SparseR(m),core.ResultBase(data=b)) z = dot(SparseR(m),core.Result(data=b))
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
except Exception, e: except Exception, e:
...@@ -121,7 +121,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -121,7 +121,7 @@ class _testCase_dot(unittest.TestCase):
sparse.lil_matrix]:#, sparse.coo_matrix]: sparse.lil_matrix]:#, sparse.coo_matrix]:
m = mtype(b) m = mtype(b)
ab = m.transpose().dot(a.transpose()).transpose() ab = m.transpose().dot(a.transpose()).transpose()
z = dot(core.ResultBase(data=a),SparseR(mtype(b))) z = dot(core.Result(data=a),SparseR(mtype(b)))
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
......
"""A simple class to store ndarray data """ """A simple class to store ndarray data """
from gof import ResultBase, Op, utils, AbstractFunctionError from gof import Result, Op, utils, AbstractFunctionError
import numpy import numpy
from copy import copy from copy import copy
...@@ -9,8 +9,8 @@ from copy import copy ...@@ -9,8 +9,8 @@ from copy import copy
# BaseTensor Class # BaseTensor Class
########################### ###########################
class BaseTensor(ResultBase): class BaseTensor(Result):
"""ResultBase to store numpy.ndarray or equivalent via .data """Result to store numpy.ndarray or equivalent via .data
Attributes: Attributes:
_dtype - numpy dtype string such as 'int64' or 'float64' (among others) _dtype - numpy dtype string such as 'int64' or 'float64' (among others)
...@@ -43,13 +43,13 @@ class BaseTensor(ResultBase): ...@@ -43,13 +43,13 @@ class BaseTensor(ResultBase):
# the argument that is awkward to construct, I decided to put all this # the argument that is awkward to construct, I decided to put all this
# into the tensor(data,...) function below, which is like a second # into the tensor(data,...) function below, which is like a second
# constructor that works with an ndarray. # constructor that works with an ndarray.
ResultBase.__init__(self, role=role, name=name) Result.__init__(self, role=role, name=name)
self._dtype = str(dtype) self._dtype = str(dtype)
self.dtype_specs() # this is just for error checking self.dtype_specs() # this is just for error checking
self._broadcastable = tuple(broadcastable) self._broadcastable = tuple(broadcastable)
###################### ######################
# ResultBase interface # Result interface
###################### ######################
# #
......
...@@ -3,14 +3,14 @@ import unittest ...@@ -3,14 +3,14 @@ import unittest
from link import PerformLinker, Profiler from link import PerformLinker, Profiler
from cc import * from cc import *
from result import ResultBase from result import Result
from op import Op from op import Op
from env import Env from env import Env
class Double(ResultBase): class Double(Result):
def __init__(self, data, name = "oignon"): def __init__(self, data, name = "oignon"):
ResultBase.__init__(self, role = None, name = name) Result.__init__(self, role = None, name = name)
assert isinstance(data, float) assert isinstance(data, float)
self.data = data self.data = data
......
import unittest import unittest
from result import ResultBase from result import Result
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer from opt import PatternOptimizer, OpSubOptimizer
......
...@@ -4,14 +4,14 @@ import unittest ...@@ -4,14 +4,14 @@ import unittest
from graph import * from graph import *
from op import Op from op import Op
from result import ResultBase from result import Result
class MyResult(ResultBase): class MyResult(Result):
def __init__(self, thingy): def __init__(self, thingy):
self.thingy = thingy self.thingy = thingy
ResultBase.__init__(self, role = None ) Result.__init__(self, role = None )
self.data = [self.thingy] self.data = [self.thingy]
def __eq__(self, other): def __eq__(self, other):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import unittest import unittest
from result import ResultBase from result import Result
from op import Op from op import Op
from env import Env from env import Env
......
...@@ -4,15 +4,15 @@ import unittest ...@@ -4,15 +4,15 @@ import unittest
from modes import * from modes import *
from result import ResultBase from result import Result
from op import Op from op import Op
from env import Env from env import Env
class Double(ResultBase): class Double(Result):
def __init__(self, data, name = "oignon"): def __init__(self, data, name = "oignon"):
ResultBase.__init__(self, role = None, name = name) Result.__init__(self, role = None, name = name)
assert isinstance(data, float) assert isinstance(data, float)
self.data = data self.data = data
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
import unittest import unittest
from copy import copy from copy import copy
from op import * from op import *
from result import ResultBase from result import Result
class MyResult(ResultBase): class MyResult(Result):
def __init__(self, thingy): def __init__(self, thingy):
self.thingy = thingy self.thingy = thingy
ResultBase.__init__(self, role = None) Result.__init__(self, role = None)
self.data = [self.thingy] self.data = [self.thingy]
def __eq__(self, other): def __eq__(self, other):
...@@ -48,7 +48,7 @@ class _test_Op(unittest.TestCase): ...@@ -48,7 +48,7 @@ class _test_Op(unittest.TestCase):
# validate_update # validate_update
def test_validate_update(self): def test_validate_update(self):
try: try:
MyOp(ResultBase(), MyResult(1)) # MyOp requires MyResult instances MyOp(Result(), MyResult(1)) # MyOp requires MyResult instances
except Exception, e: except Exception, e:
assert str(e) == "Error 1" assert str(e) == "Error 1"
else: else:
......
import unittest import unittest
from result import ResultBase from result import Result
from op import Op from op import Op
from ext import Destroyer from ext import Destroyer
from opt import * from opt import *
...@@ -9,10 +9,10 @@ from env import Env ...@@ -9,10 +9,10 @@ from env import Env
from toolbox import * from toolbox import *
class MyResult(ResultBase): class MyResult(Result):
def __init__(self, name): def __init__(self, name):
ResultBase.__init__(self, role = None, name = name) Result.__init__(self, role = None, name = name)
self.data = [1000] self.data = [1000]
def __str__(self): def __str__(self):
......
...@@ -3,10 +3,10 @@ import unittest ...@@ -3,10 +3,10 @@ import unittest
from result import * from result import *
class Double(ResultBase): class Double(Result):
def __init__(self, data, name = "oignon"): def __init__(self, data, name = "oignon"):
ResultBase.__init__(self, role = None, name = name) Result.__init__(self, role = None, name = name)
assert isinstance(data, float) assert isinstance(data, float)
self.data = data self.data = data
...@@ -19,10 +19,10 @@ class Double(ResultBase): ...@@ -19,10 +19,10 @@ class Double(ResultBase):
def __copy__(self): def __copy__(self):
return Double(self.data, self.name) return Double(self.data, self.name)
class MyResult(ResultBase): class MyResult(Result):
def __init__(self, name): def __init__(self, name):
ResultBase.__init__(self, role = None, name = name) Result.__init__(self, role = None, name = name)
self.data = [1000] self.data = [1000]
def __str__(self): def __str__(self):
...@@ -35,12 +35,12 @@ class MyResult(ResultBase): ...@@ -35,12 +35,12 @@ class MyResult(ResultBase):
return MyResult(self.name) return MyResult(self.name)
class _test_ResultBase(unittest.TestCase): class _test_Result(unittest.TestCase):
def test_trivial(self): def test_trivial(self):
r = ResultBase() r = Result()
def test_state(self): def test_state(self):
r = ResultBase() r = Result()
assert r.state is Empty assert r.state is Empty
r.data = 0 r.data = 0
......
import unittest import unittest
from result import ResultBase from result import Result
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer from opt import PatternOptimizer, OpSubOptimizer
...@@ -10,10 +10,10 @@ from toolbox import * ...@@ -10,10 +10,10 @@ from toolbox import *
class MyResult(ResultBase): class MyResult(Result):
def __init__(self, name): def __init__(self, name):
ResultBase.__init__(self, role = None, name = name) Result.__init__(self, role = None, name = name)
self.data = [1000] self.data = [1000]
def __str__(self): def __str__(self):
......
from op import Op from op import Op
from result import ResultBase from result import Result
from env import InconsistencyError from env import InconsistencyError
import utils import utils
import unify import unify
...@@ -306,9 +306,9 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -306,9 +306,9 @@ class PatternOptimizer(OpSpecificOptimizer):
return False return False
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif isinstance(pattern, ResultBase) \ elif isinstance(pattern, Result) \
and getattr(pattern, 'constant', False) \ and getattr(pattern, 'constant', False) \
and isinstance(expr, ResultBase) \ and isinstance(expr, Result) \
and getattr(expr, 'constant', False) \ and getattr(expr, 'constant', False) \
and pattern.desc() == expr.desc(): and pattern.desc() == expr.desc():
return u return u
...@@ -415,9 +415,9 @@ class PatternDescOptimizer(LocalOptimizer): ...@@ -415,9 +415,9 @@ class PatternDescOptimizer(LocalOptimizer):
return False return False
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif isinstance(pattern, ResultBase) \ elif isinstance(pattern, Result) \
and getattr(pattern, 'constant', False) \ and getattr(pattern, 'constant', False) \
and isinstance(expr, ResultBase) \ and isinstance(expr, Result) \
and getattr(expr, 'constant', False) \ and getattr(expr, 'constant', False) \
and pattern.desc() == expr.desc(): and pattern.desc() == expr.desc():
return u return u
......
...@@ -9,7 +9,7 @@ import utils ...@@ -9,7 +9,7 @@ import utils
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['ResultBase', __all__ = ['Result',
'PythonResult', 'PythonResult',
'StateError', 'StateError',
'Empty', 'Empty',
...@@ -24,7 +24,7 @@ class StateError(Exception): ...@@ -24,7 +24,7 @@ class StateError(Exception):
"""The state of the Result is a problem""" """The state of the Result is a problem"""
# ResultBase state keywords # Result state keywords
class Empty : """Memory has not been allocated""" class Empty : """Memory has not been allocated"""
class Allocated: """Memory has been allocated, contents are not the owner's output.""" class Allocated: """Memory has been allocated, contents are not the owner's output."""
class Computed : """Memory has been allocated, contents are the owner's output.""" class Computed : """Memory has been allocated, contents are the owner's output."""
...@@ -34,7 +34,7 @@ class Computed : """Memory has been allocated, contents are the owner's output." ...@@ -34,7 +34,7 @@ class Computed : """Memory has been allocated, contents are the owner's output."
# Result # Result
############################ ############################
class ResultBase(object): class Result(object):
"""Base class for storing Op inputs and outputs """Base class for storing Op inputs and outputs
Attributes: Attributes:
...@@ -299,7 +299,7 @@ class ResultBase(object): ...@@ -299,7 +299,7 @@ class ResultBase(object):
raise AbstractFunctionError() raise AbstractFunctionError()
class PythonResult(ResultBase): class PythonResult(Result):
""" """
Represents a generic Python object. The object is available Represents a generic Python object. The object is available
through %(name)s. through %(name)s.
......
...@@ -13,7 +13,7 @@ def _unpack_result(lst): ...@@ -13,7 +13,7 @@ def _unpack_result(lst):
return lst[0] return lst[0]
def _pack_result(arg): def _pack_result(arg):
if isinstance(arg, gof.result.ResultBase): if isinstance(arg, gof.result.Result):
return [arg] return [arg]
else: else:
return arg return arg
......
...@@ -5,7 +5,7 @@ import math ...@@ -5,7 +5,7 @@ import math
from copy import copy from copy import copy
import inspect import inspect
from gof import ResultBase, GuardedOp, utils from gof import Result, GuardedOp, utils
def as_scalar(x, name = None): def as_scalar(x, name = None):
...@@ -21,10 +21,10 @@ def as_scalar(x, name = None): ...@@ -21,10 +21,10 @@ def as_scalar(x, name = None):
return x return x
class Scalar(ResultBase): class Scalar(Result):
def __init__(self, dtype, name = None): def __init__(self, dtype, name = None):
ResultBase.__init__(self, role = None, name = name) Result.__init__(self, role = None, name = name)
self.dtype = dtype self.dtype = dtype
self.dtype_specs() self.dtype_specs()
......
...@@ -22,7 +22,7 @@ def assparse(sp, **kwargs): ...@@ -22,7 +22,7 @@ def assparse(sp, **kwargs):
rval.data = sp rval.data = sp
return rval return rval
class SparseR(gof.result.ResultBase): class SparseR(gof.result.Result):
""" """
Attribute: Attribute:
format - a string identifying the type of sparsity format - a string identifying the type of sparsity
...@@ -49,7 +49,7 @@ class SparseR(gof.result.ResultBase): ...@@ -49,7 +49,7 @@ class SparseR(gof.result.ResultBase):
@return An empty SparseR instance. @return An empty SparseR instance.
""" """
gof.ResultBase.__init__(self, **kwargs) gof.Result.__init__(self, **kwargs)
if dtype in SparseR.dtype_set: if dtype in SparseR.dtype_set:
self._dtype = dtype self._dtype = dtype
assert isinstance(format, str) assert isinstance(format, str)
...@@ -109,7 +109,7 @@ dense_from_sparse = gof.op.constructor(DenseFromSparse) ...@@ -109,7 +109,7 @@ dense_from_sparse = gof.op.constructor(DenseFromSparse)
class SparseFromDense(gof.op.Op): class SparseFromDense(gof.op.Op):
def __init__(self, x, format, **kwargs): def __init__(self, x, format, **kwargs):
gof.op.Op.__init__(self, **kwargs) gof.op.Op.__init__(self, **kwargs)
if isinstance(format, gof.result.ResultBase): if isinstance(format, gof.result.Result):
self.inputs = [tensor.astensor(x), format] self.inputs = [tensor.astensor(x), format]
else: else:
self.inputs = [tensor.astensor(x), gof.result.PythonResult()] self.inputs = [tensor.astensor(x), gof.result.PythonResult()]
...@@ -157,39 +157,41 @@ class AddSS(gof.op.Op): #add two sparse matrices ...@@ -157,39 +157,41 @@ class AddSS(gof.op.Op): #add two sparse matrices
add_s_s = gof.op.constructor(AddSS) add_s_s = gof.op.constructor(AddSS)
if 0: class Dot(gof.op.Op):
class dot(gof.op.Op): def __init__(self, x, y):
""" def perform:
Attributes: #return numpy.dot(x, y)
grad_preserves_dense - an array of boolean flags (described below) def grad:
grad_preserves_dense controls whether gradients with respect to inputs are """
converted to dense matrices when the corresponding inputs are not in a Attributes:
SparseR wrapper. This can be a good idea when dot is in the middle of a grad_preserves_dense - an array of boolean flags (described below)
larger graph, because the types of gx and gy will match those of x and y.
This conversion might be annoying if the gradients are graph outputs though,
hence this mask.
"""
def __init__(self, *args, **kwargs):
gof.op.Op.__init__(self, **kwargs)
self.grad_preserves_dense = [True, True]
def gen_outputs(self): return [SparseR()]
def impl(x,y):
if hasattr(x, 'getnnz'):
# if x is sparse, then do this.
return x.dot(y)
else:
# if x is dense (and y is sparse), we do this
return y.transpose().dot(x.transpose()).transpose()
def grad(self, x, y, gz):
rval = [dot(gz, y.T), dot(x.T, gz)]
for i in 0,1:
if not isinstance(self.inputs[i], SparseR):
#assume it is a dense matrix
if self.grad_preserves_dense[i]:
rval[i] = dense_from_sparse(rval[i])
return rval
grad_preserves_dense controls whether gradients with respect to inputs are
converted to dense matrices when the corresponding inputs are not in a
SparseR wrapper. This can be a good idea when dot is in the middle of a
larger graph, because the types of gx and gy will match those of x and y.
This conversion might be annoying if the gradients are graph outputs though,
hence this mask.
"""
def __init__(self, *args, **kwargs):
gof.op.Op.__init__(self, **kwargs)
self.grad_preserves_dense = [True, True]
def gen_outputs(self): return [SparseR()]
def impl(x,y):
if hasattr(x, 'getnnz'):
# if x is sparse, then do this.
return x.dot(y)
else:
# if x is dense (and y is sparse), we do this
return y.transpose().dot(x.transpose()).transpose()
def grad(self, x, y, gz):
rval = [dot(gz, y.T), dot(x.T, gz)]
for i in 0,1:
if not isinstance(self.inputs[i], SparseR):
#assume it is a dense matrix
if self.grad_preserves_dense[i]:
rval[i] = dense_from_sparse(rval[i])
return rval
"""A ResultBase to store numpy.ndarray with basic accompanying Ops""" """A Result to store numpy.ndarray with basic accompanying Ops"""
import sys # for sys.maxint import sys # for sys.maxint
import inspect import inspect
import numpy import numpy
from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError from gof import Result, Op, utils, Destroyer, Viewer, AbstractFunctionError
import gof.result import gof.result
import gof.op import gof.op
...@@ -260,7 +260,7 @@ class Subtensor(Op, Viewer): ...@@ -260,7 +260,7 @@ class Subtensor(Op, Viewer):
debug = 0 debug = 0
def __init__(self, *args,**kwargs): def __init__(self, *args,**kwargs):
def as_tuple_result(obj): def as_tuple_result(obj):
if isinstance(obj, ResultBase): if isinstance(obj, Result):
return obj return obj
r = gof.result.PythonResult(None) r = gof.result.PythonResult(None)
if isinstance(obj, tuple): if isinstance(obj, tuple):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论