提交 1cbd5f2e authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Move tests.gof.test_opt's test types to tests.gof.utils

上级 d8a82f73
import theano.tensor as tt import theano.tensor as tt
from tests.gof.utils import MyType, MyVariable, op1, op2, op3, op4, op5, op6, op_y, op_z
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.opt import ( from theano.gof.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
...@@ -15,82 +16,11 @@ from theano.gof.opt import ( ...@@ -15,82 +16,11 @@ from theano.gof.opt import (
pre_greedy_local_optimizer, pre_greedy_local_optimizer,
theano, theano,
) )
from theano.gof.type import Type
from theano.tensor.opt import constant_folding from theano.tensor.opt import constant_folding
from theano.tensor.subtensor import AdvancedSubtensor from theano.tensor.subtensor import AdvancedSubtensor
from theano.tensor.type_other import MakeSlice, SliceConstant, slicetype from theano.tensor.type_other import MakeSlice, SliceConstant, slicetype
def is_variable(x):
if not isinstance(x, Variable):
raise TypeError("not a Variable", x)
return x
class MyType(Type):
def filter(self, data):
return data
def __eq__(self, other):
return isinstance(other, MyType)
def __hash__(self):
return hash(MyType)
def MyVariable(name):
return Variable(MyType(), None, None, name=name)
class MyOp(Op):
def __init__(self, name, dmap=None, x=None):
self.name = name
if dmap is None:
dmap = {}
self.destroy_map = dmap
self.x = x
def make_node(self, *inputs):
inputs = list(map(is_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType()()]
return Apply(self, inputs, outputs)
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
# rval = (self is other) or (isinstance(other, MyOp) and self.x is not None and self.x == other.x and self.name == other.name)
rval = (self is other) or (
isinstance(other, MyOp) and self.x is not None and self.x == other.x
)
return rval
def __hash__(self):
# return hash(self.x if self.x is not None else id(self)) ^ hash(self.name)
if self.x is not None:
return hash(self.x)
else:
return id(self)
op1 = MyOp("Op1")
op2 = MyOp("Op2")
op3 = MyOp("Op3")
op4 = MyOp("Op4")
op5 = MyOp("Op5")
op6 = MyOp("Op6")
op_d = MyOp("OpD", {0: [0]})
op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1)
def inputs(): def inputs():
x = MyVariable("x") x = MyVariable("x")
y = MyVariable("y") y = MyVariable("y")
......
from theano.gof.graph import Apply, Variable
from theano.gof.op import Op
from theano.gof.type import Type
def is_variable(x):
if not isinstance(x, Variable):
raise TypeError("not a Variable", x)
return x
class MyType(Type):
def filter(self, data):
return data
def __eq__(self, other):
return isinstance(other, MyType)
def __hash__(self):
return hash(MyType)
def MyVariable(name):
return Variable(MyType(), None, None, name=name)
class MyOp(Op):
def __init__(self, name, dmap=None, x=None):
self.name = name
if dmap is None:
dmap = {}
self.destroy_map = dmap
self.x = x
def make_node(self, *inputs):
inputs = list(map(is_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType()()]
return Apply(self, inputs, outputs)
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
# rval = (self is other) or (isinstance(other, MyOp) and self.x is not None and self.x == other.x and self.name == other.name)
rval = (self is other) or (
isinstance(other, MyOp) and self.x is not None and self.x == other.x
)
return rval
def __hash__(self):
# return hash(self.x if self.x is not None else id(self)) ^ hash(self.name)
if self.x is not None:
return hash(self.x)
else:
return id(self)
op1 = MyOp("Op1")
op2 = MyOp("Op2")
op3 = MyOp("Op3")
op4 = MyOp("Op4")
op5 = MyOp("Op5")
op6 = MyOp("Op6")
op_d = MyOp("OpD", {0: [0]})
op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论