提交 77e6c81c authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5281 from nouiz/config_blas

Fix blas detection in corner case
......@@ -1255,8 +1255,10 @@ def default_blas_ldflags():
lib_path = os.path.join(sys.prefix, 'DLLs')
flags = ['-L"%s"' % lib_path]
else:
lib_path = blas_info.get('library_dirs', [])[0]
flags = ['-L%s' % lib_path]
lib_path = blas_info.get('library_dirs', [])
flags = []
if lib_path:
flags = ['-L%s' % lib_path[0]]
flags += ['-l%s' % l for l in ["mkl_core",
"mkl_intel_thread",
"mkl_rt"]]
......
......@@ -249,6 +249,21 @@ class test_gpu_ifelse(test_ifelse.test_ifelse):
def get_ifelse(self, n):
return theano.ifelse.IfElse(n, gpu=True, as_view=True)
def test_lifter_with_inputs_of_graph(self):
x = tensor.vector()
cond = tensor.iscalar()
f = theano.function([x, cond],
theano.ifelse.ifelse(cond, x.mean(), x.sum()),
mode=mode_with_gpu)
assert f(numpy.float32([1, 2, 3]), 0) == 6
x = tensor.vector()
cond = tensor.scalar()
f = theano.function([x, cond],
theano.ifelse.ifelse(cond, x.mean(), x.sum()),
mode=mode_with_gpu)
assert f(numpy.float32([1, 2, 3]), 0) == 6
def test_print_op():
""" Test that print ops don't block gpu optimization"""
......
......@@ -167,10 +167,10 @@ class IfElse(Op):
"Wrong number of arguments to make_node: "
"expected %d, got %d" % (2 * self.n_outs, len(args))
)
c = theano.tensor.as_tensor_variable(c)
if not self.gpu:
# When gpu is true, we are given only cuda ndarrays, and we want
# to keep them be cuda ndarrays
c = theano.tensor.as_tensor_variable(c)
nw_args = []
for x in args:
if isinstance(x, theano.Variable):
......
......@@ -1280,6 +1280,19 @@ def test_grad_useless_sum():
[-1.]])
def test_elemwise_grad_broadcast():
# This crashed in the past.
x = tensor.tensor(dtype='float32',
broadcastable=(True, False, False, False))
y = tensor.tensor(dtype='float32',
broadcastable=(True, True, False, False))
theano.grad(theano.tensor.tanh(x).sum(), x)
theano.grad(theano.tensor.tanh(x + y).sum(), y)
theano.grad(theano.tensor.tanh(x + y).sum(), [x, y])
def test_clip_grad_int():
# test that integers don't crash clip gradient
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论