提交 7790833a authored 作者: lamblin's avatar lamblin

Merge pull request #1132 from nouiz/misc2

Misc2
...@@ -205,6 +205,8 @@ if __name__ == "__main__": ...@@ -205,6 +205,8 @@ if __name__ == "__main__":
C2075/4.2 0.25s C2075/4.2 0.25s
GTX285/4.2 0.452s #cuda 3.0 seam faster? driver version? GTX285/4.2 0.452s #cuda 3.0 seam faster? driver version?
GTX460/4.0 0.45s
GTX580/3.2 0.203s GTX580/3.2 0.203s
GTX680/3.2 0.218s GTX680/3.2 0.218s
GTX480/3.2 0.237s GTX480/3.2 0.237s
......
...@@ -59,7 +59,9 @@ class HostFromGpu(GpuOp): ...@@ -59,7 +59,9 @@ class HostFromGpu(GpuOp):
def make_node(self, x): def make_node(self, x):
if not isinstance(x.type, CudaNdarrayType): if not isinstance(x.type, CudaNdarrayType):
raise TypeError(x) raise TypeError("Expected a Theano variable with type "
"CudaNdarrayType. Got %s with type %s" % (x,
x.type))
return Apply(self, [x], [tensor.TensorType(dtype=x.dtype, return Apply(self, [x], [tensor.TensorType(dtype=x.dtype,
broadcastable=x.broadcastable)()]) broadcastable=x.broadcastable)()])
...@@ -113,7 +115,9 @@ class GpuFromHost(GpuOp): ...@@ -113,7 +115,9 @@ class GpuFromHost(GpuOp):
def make_node(self, x): def make_node(self, x):
if not isinstance(x.type, tensor.TensorType): if not isinstance(x.type, tensor.TensorType):
raise TypeError(x) raise TypeError("Expected a Theano variable with type "
"TensorType. Got %s with type %s" % (x,
x.type))
return Apply(self, [x], [CudaNdarrayType(broadcastable=x.broadcastable, return Apply(self, [x], [CudaNdarrayType(broadcastable=x.broadcastable,
dtype=x.dtype)()]) dtype=x.dtype)()])
......
...@@ -6,16 +6,17 @@ import numpy as np ...@@ -6,16 +6,17 @@ import numpy as np
import theano import theano
from theano.printing import var_descriptor from theano.printing import var_descriptor
from cStringIO import StringIO from cStringIO import StringIO
from nose.plugins.skip import SkipTest
from theano import config from theano import config
from theano import shared from theano import shared
def sharedX(x, name=None): def sharedX(x, name=None):
x = np.cast[config.floatX](x) x = np.cast[config.floatX](x)
return shared(x, name) return shared(x, name)
def test_determinism_1(): def test_determinism_1():
# Tests that repeatedly running a script that compiles and # Tests that repeatedly running a script that compiles and
...@@ -24,7 +25,12 @@ def test_determinism_1(): ...@@ -24,7 +25,12 @@ def test_determinism_1():
# change. # change.
# This specific script is capable of catching a bug where # This specific script is capable of catching a bug where
# FunctionGraph.toposort was non-deterministic. # FunctionGraph.toposort was non-deterministic.
def run(replay, log = None): try:
import hashlib
except ImportError:
raise SkipTest('python version too old to do this test')
def run(replay, log=None):
if not replay: if not replay:
log = StringIO() log = StringIO()
...@@ -32,7 +38,6 @@ def test_determinism_1(): ...@@ -32,7 +38,6 @@ def test_determinism_1():
log = StringIO(log) log = StringIO(log)
record = Record(replay=replay, file_object=log) record = Record(replay=replay, file_object=log)
disturb_mem.disturb_mem() disturb_mem.disturb_mem()
mode = RecordMode(record=record) mode = RecordMode(record=record)
...@@ -53,19 +58,20 @@ def test_determinism_1(): ...@@ -53,19 +58,20 @@ def test_determinism_1():
v_range.max(), v_range.max(),
]): ]):
disturb_mem.disturb_mem() disturb_mem.disturb_mem()
s = sharedX(0., name='s_'+str(i)) s = sharedX(0., name='s_' + str(i))
updates.append((s, val)) updates.append((s, val))
for var in theano.gof.graph.ancestors(update for var, update in updates): for var in theano.gof.graph.ancestors(update for _, update in updates):
if var.name is not None and var.name is not 'b': if var.name is not None and var.name is not 'b':
if var.name[0] != 's' or len(var.name) != 2: if var.name[0] != 's' or len(var.name) != 2:
var.name = None var.name = None
for key in channels: for key in channels:
updates.append((s, channels[key])) updates.append((s, channels[key]))
f = theano.function([], mode=mode, updates=updates, on_unused_input='ignore', name='f') f = theano.function([], mode=mode, updates=updates,
on_unused_input='ignore', name='f')
for output in f.maker.fgraph.outputs: for output in f.maker.fgraph.outputs:
mode.record.handle_line(var_descriptor(output)+'\n') mode.record.handle_line(var_descriptor(output) + '\n')
disturb_mem.disturb_mem() disturb_mem.disturb_mem()
f() f()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论