提交 67016ad1 authored 作者: affanv14's avatar affanv14

add modes to test

上级 1a919959
...@@ -714,6 +714,8 @@ class Conv_opt_test(unittest.TestCase): ...@@ -714,6 +714,8 @@ class Conv_opt_test(unittest.TestCase):
if(direction == 0): if(direction == 0):
conv_op = abstract_conv.conv2d(inp1, conv_op = abstract_conv.conv2d(inp1,
inp2, inp2,
input_shapes[0],
input_shapes[1],
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_dilation=filter_dilation) filter_dilation=filter_dilation)
...@@ -738,9 +740,9 @@ class Conv_opt_test(unittest.TestCase): ...@@ -738,9 +740,9 @@ class Conv_opt_test(unittest.TestCase):
theano.config.metaopt.optimizer_including = include_tags theano.config.metaopt.optimizer_including = include_tags
theano.config.metaopt.optimizer_excluding = exclude_tags theano.config.metaopt.optimizer_excluding = exclude_tags
mode = theano.Mode().including('conv_meta') mode = mode_with_gpu.including('conv_meta')
ref_func = theano.function([], conv_op) ref_func = theano.function([], conv_op, mode=mode_with_gpu)
conv_func = theano.function([], conv_op, mode=mode) conv_func = theano.function([], conv_op, mode=mode)
assert any([isinstance(node.op, op) assert any([isinstance(node.op, op)
for node in conv_func.maker.fgraph.toposort()]) for node in conv_func.maker.fgraph.toposort()])
...@@ -780,9 +782,9 @@ class Conv_opt_test(unittest.TestCase): ...@@ -780,9 +782,9 @@ class Conv_opt_test(unittest.TestCase):
theano.config.metaopt.optimizer_including = include_tags theano.config.metaopt.optimizer_including = include_tags
theano.config.metaopt.optimizer_excluding = exclude_tags theano.config.metaopt.optimizer_excluding = exclude_tags
mode = theano.Mode().including('conv_meta') mode = mode_with_gpu.including('conv_meta')
ref_func = theano.function([], conv_op) ref_func = theano.function([], conv_op, mode=mode_with_gpu)
conv_func = theano.function([], conv_op, mode=mode) conv_func = theano.function([], conv_op, mode=mode)
if op is not None: if op is not None:
assert any([isinstance(node.op, op) assert any([isinstance(node.op, op)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论