提交 2a6e5d6f authored 作者: nouiz's avatar nouiz

Merge pull request #1186 from goodfeli/fix_constant

fix a NotScalarConstantError issue
......@@ -1615,7 +1615,7 @@ def local_gemm_to_ger(node):
yv = y.dimshuffle(1)
try:
bval = T.get_scalar_constant_value(b)
except TypeError:
except T.NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
......
......@@ -1603,6 +1603,13 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(1.5)).owner)
def test_b_nonconst_does_not_triggers_ger(self):
""" test local_gemm_to_ger opt"""
assert not T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(
self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.a).owner)
def test_outer(self):
f = self.function([self.x, self.y], T.outer(self.x, self.y))
self.assertFunctionContains(f, self.ger_destructive)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论