提交 98af764b authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use shared() to allow the test to run on gpu without problems.

上级 9fcdcdba
...@@ -1472,14 +1472,14 @@ class BaseGemv(object): ...@@ -1472,14 +1472,14 @@ class BaseGemv(object):
x_v = x_v.astype("float32") x_v = x_v.astype("float32")
y_v = y_v.astype("float32") y_v = y_v.astype("float32")
alpha = T.dscalar('a') alpha = T.dscalar('alpha')
a = T.fmatrix('w') a = self.shared(a_v)
x = T.fvector('v') x = self.shared(x_v)
y = T.fvector('t') y = self.shared(y_v)
rval = T.dot(a, x) * alpha + y rval = T.dot(a, x) * alpha + y
f = theano.function([a, x, y, alpha], rval, mode=self.mode) f = theano.function([alpha], rval, mode=self.mode)
# this function is currently optimized so that the gemv is # this function is currently optimized so that the gemv is
# done inplace on a temporarily allocated-buffer, which is # done inplace on a temporarily allocated-buffer, which is
# then scaled by alpha and to t with a fused elemwise. # then scaled by alpha and to t with a fused elemwise.
...@@ -1491,7 +1491,7 @@ class BaseGemv(object): ...@@ -1491,7 +1491,7 @@ class BaseGemv(object):
assert node.outputs[0].dtype == 'float32' assert node.outputs[0].dtype == 'float32'
assert n_gemvs == 1, n_gemvs assert n_gemvs == 1, n_gemvs
self.assertFunctionContains1(f, self.gemv_inplace) self.assertFunctionContains1(f, self.gemv_inplace)
f(a_v, x_v, y_v, alpha_v) f(alpha_v)
class TestSgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin): class TestSgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论