提交 bd3cac42 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #260 from delallea/lazy_import

Fixed test_lazy.py and test_vm.py Everything looks fine.
...@@ -6,7 +6,7 @@ from theano.gof.op import PureOp ...@@ -6,7 +6,7 @@ from theano.gof.op import PureOp
from theano.gof import Apply, generic, Container from theano.gof import Apply, generic, Container
from theano.gof.link import LocalLinker, map_storage, add_clear_storage from theano.gof.link import LocalLinker, map_storage, add_clear_storage
from theano import function, Mode from theano import function, Mode
from theano.lazycond import ifelse from theano.ifelse import ifelse
import theano.tensor as T import theano.tensor as T
......
...@@ -15,7 +15,7 @@ from theano.gof import OpWiseCLinker ...@@ -15,7 +15,7 @@ from theano.gof import OpWiseCLinker
from theano.compile import Mode from theano.compile import Mode
from theano import tensor from theano import tensor
from theano.lazycond import ifelse from theano.ifelse import ifelse
import theano import theano
class TestCallbacks(unittest.TestCase): class TestCallbacks(unittest.TestCase):
...@@ -26,8 +26,9 @@ class TestCallbacks(unittest.TestCase): ...@@ -26,8 +26,9 @@ class TestCallbacks(unittest.TestCase):
self.n_callbacks = {} self.n_callbacks = {}
def callback(self, node, thunk, storage_map, compute_map): def callback(self, node, thunk, storage_map, compute_map):
self.n_callbacks.setdefault(node.op, 0) key = node.op.__class__.__name__
self.n_callbacks[node.op] += 1 self.n_callbacks.setdefault(key, 0)
self.n_callbacks[key] += 1
def test_callback(self): def test_callback(self):
a, b, c = tensor.scalars('abc') a, b, c = tensor.scalars('abc')
...@@ -50,7 +51,7 @@ class TestCallbacks(unittest.TestCase): ...@@ -50,7 +51,7 @@ class TestCallbacks(unittest.TestCase):
linker=vm.VM_Linker(callback=self.callback))) linker=vm.VM_Linker(callback=self.callback)))
f(1, 2, 3) f(1, 2, 3)
assert self.n_callbacks[ifelse] == 2 assert self.n_callbacks['IfElse'] == 2
def test_speed(): def test_speed():
...@@ -132,7 +133,7 @@ def test_speed_lazy(): ...@@ -132,7 +133,7 @@ def test_speed_lazy():
def build_graph(x, depth=5): def build_graph(x, depth=5):
z = x z = x
for d in range(depth): for d in range(depth):
z = ifelse(z> 0, -z, z) z = ifelse(z[0] > 0, -z, z)
return z return z
def time_linker(name, linker): def time_linker(name, linker):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论