提交 ae5b2696 authored 作者: Frederic's avatar Frederic

pep8

上级 03435c8c
......@@ -14,15 +14,16 @@ import theano.tensor as T
class IfElseIfElseIf(PureOp):
def __init__(self, inplace=False):
self.inplace=inplace # check destroyhandler and others to ensure that a view_map with
# check destroyhandler and others to ensure that a view_map with
self.inplace = inplace
#multiple inputs can work
assert not self.inplace
def make_node(self, c1, t1, c2,t2,c3,t3,f3):
def make_node(self, c1, t1, c2, t2, c3, t3, f3):
assert t1.type == f3.type
assert t2.type == t3.type
assert t3.type == f3.type
return Apply(self, [c1,t1,c2,t2,c3,t3,f3], [t1.type()])
return Apply(self, [c1, t1, c2, t2, c3, t3, f3], [t1.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
......@@ -42,8 +43,9 @@ class IfElseIfElseIf(PureOp):
if not input_computed[1][0]:
return [1]
else:
output_computed[0][0]=1
output_registers[0][0]=outtype.filter(deepcopy(input_registers[1][0]))
output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(
deepcopy(input_registers[1][0]))
return []
else:
if not input_computed[2][0]:
......@@ -55,7 +57,8 @@ class IfElseIfElseIf(PureOp):
return [3]
else:
output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[3][0]))
output_registers[0][0] = outtype.filter(
deepcopy(input_registers[3][0]))
return []
else:
if not input_computed[4][0]:
......@@ -67,30 +70,33 @@ class IfElseIfElseIf(PureOp):
return [5]
else:
output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[5][0]))
output_registers[0][0] = outtype.filter(
deepcopy(input_registers[5][0]))
return []
else:
if not input_computed[6][0]:
return [6]
else:
output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[6][0]))
output_registers[0][0] = outtype.filter(
deepcopy(input_registers[6][0]))
return []
thunk.lazy = True
return thunk
class NotImplementedOp(PureOp):
class E(Exception): pass
class E(Exception):
pass
def make_node(self, x):
return Apply(self, [x], [x.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
def thunk():
raise self.E()
thunk.lazy=False
thunk.lazy = False
return thunk
......@@ -132,19 +138,20 @@ def more_complex_test():
x2 = T.scalar('x2')
c1 = T.scalar('c1')
c2 = T.scalar('c2')
t1 = ifelse(c1,x1,notimpl(x2))
t1 = ifelse(c1, x1, notimpl(x2))
t1.name = 't1'
t2 = t1*10
t2 = t1 * 10
t2.name = 't2'
t3 = ifelse(c2,t2, x1+t1)
t3 = ifelse(c2, t2, x1 + t1)
t3.name = 't3'
t4 = ifelseifelseif(T.eq(x1,x2), x1, T.eq(x1,5), x2, c2, t3, t3+0.5)
t4 = ifelseifelseif(T.eq(x1, x2), x1, T.eq(x1, 5), x2, c2, t3, t3 + 0.5)
t4.name = 't4'
f = function([c1,c2,x1,x2], t4, mode=Mode(linker='vm', optimizer='fast_run'))
f = function([c1, c2, x1, x2], t4, mode=Mode(linker='vm',
optimizer='fast_run'))
print f(1, 0, numpy.array(10,dtype=x1.dtype),0)
assert f(1,0,numpy.array(10,dtype=x1.dtype),0) == 20.5
print f(1, 0, numpy.array(10, dtype=x1.dtype), 0)
assert f(1, 0, numpy.array(10, dtype=x1.dtype), 0) == 20.5
print '... passed'
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论