提交 ae5b2696 authored 作者: Frederic's avatar Frederic

pep8

上级 03435c8c
...@@ -14,15 +14,16 @@ import theano.tensor as T ...@@ -14,15 +14,16 @@ import theano.tensor as T
class IfElseIfElseIf(PureOp): class IfElseIfElseIf(PureOp):
def __init__(self, inplace=False): def __init__(self, inplace=False):
self.inplace=inplace # check destroyhandler and others to ensure that a view_map with # check destroyhandler and others to ensure that a view_map with
self.inplace = inplace
#multiple inputs can work #multiple inputs can work
assert not self.inplace assert not self.inplace
def make_node(self, c1, t1, c2,t2,c3,t3,f3): def make_node(self, c1, t1, c2, t2, c3, t3, f3):
assert t1.type == f3.type assert t1.type == f3.type
assert t2.type == t3.type assert t2.type == t3.type
assert t3.type == f3.type assert t3.type == f3.type
return Apply(self, [c1,t1,c2,t2,c3,t3,f3], [t1.type()]) return Apply(self, [c1, t1, c2, t2, c3, t3, f3], [t1.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
...@@ -42,8 +43,9 @@ class IfElseIfElseIf(PureOp): ...@@ -42,8 +43,9 @@ class IfElseIfElseIf(PureOp):
if not input_computed[1][0]: if not input_computed[1][0]:
return [1] return [1]
else: else:
output_computed[0][0]=1 output_computed[0][0] = 1
output_registers[0][0]=outtype.filter(deepcopy(input_registers[1][0])) output_registers[0][0] = outtype.filter(
deepcopy(input_registers[1][0]))
return [] return []
else: else:
if not input_computed[2][0]: if not input_computed[2][0]:
...@@ -55,7 +57,8 @@ class IfElseIfElseIf(PureOp): ...@@ -55,7 +57,8 @@ class IfElseIfElseIf(PureOp):
return [3] return [3]
else: else:
output_computed[0][0] = 1 output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[3][0])) output_registers[0][0] = outtype.filter(
deepcopy(input_registers[3][0]))
return [] return []
else: else:
if not input_computed[4][0]: if not input_computed[4][0]:
...@@ -67,30 +70,33 @@ class IfElseIfElseIf(PureOp): ...@@ -67,30 +70,33 @@ class IfElseIfElseIf(PureOp):
return [5] return [5]
else: else:
output_computed[0][0] = 1 output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[5][0])) output_registers[0][0] = outtype.filter(
deepcopy(input_registers[5][0]))
return [] return []
else: else:
if not input_computed[6][0]: if not input_computed[6][0]:
return [6] return [6]
else: else:
output_computed[0][0] = 1 output_computed[0][0] = 1
output_registers[0][0] = outtype.filter(deepcopy(input_registers[6][0])) output_registers[0][0] = outtype.filter(
deepcopy(input_registers[6][0]))
return [] return []
thunk.lazy = True thunk.lazy = True
return thunk return thunk
class NotImplementedOp(PureOp): class NotImplementedOp(PureOp):
class E(Exception): pass class E(Exception):
pass
def make_node(self, x): def make_node(self, x):
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
def thunk(): def thunk():
raise self.E() raise self.E()
thunk.lazy=False thunk.lazy = False
return thunk return thunk
...@@ -132,19 +138,20 @@ def more_complex_test(): ...@@ -132,19 +138,20 @@ def more_complex_test():
x2 = T.scalar('x2') x2 = T.scalar('x2')
c1 = T.scalar('c1') c1 = T.scalar('c1')
c2 = T.scalar('c2') c2 = T.scalar('c2')
t1 = ifelse(c1,x1,notimpl(x2)) t1 = ifelse(c1, x1, notimpl(x2))
t1.name = 't1' t1.name = 't1'
t2 = t1*10 t2 = t1 * 10
t2.name = 't2' t2.name = 't2'
t3 = ifelse(c2,t2, x1+t1) t3 = ifelse(c2, t2, x1 + t1)
t3.name = 't3' t3.name = 't3'
t4 = ifelseifelseif(T.eq(x1,x2), x1, T.eq(x1,5), x2, c2, t3, t3+0.5) t4 = ifelseifelseif(T.eq(x1, x2), x1, T.eq(x1, 5), x2, c2, t3, t3 + 0.5)
t4.name = 't4' t4.name = 't4'
f = function([c1,c2,x1,x2], t4, mode=Mode(linker='vm', optimizer='fast_run')) f = function([c1, c2, x1, x2], t4, mode=Mode(linker='vm',
optimizer='fast_run'))
print f(1, 0, numpy.array(10,dtype=x1.dtype),0) print f(1, 0, numpy.array(10, dtype=x1.dtype), 0)
assert f(1,0,numpy.array(10,dtype=x1.dtype),0) == 20.5 assert f(1, 0, numpy.array(10, dtype=x1.dtype), 0) == 20.5
print '... passed' print '... passed'
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论