提交 2a6e5d6f authored 作者: nouiz's avatar nouiz

Merge pull request #1186 from goodfeli/fix_constant

fix a NotScalarConstantError issue
...@@ -1615,7 +1615,7 @@ def local_gemm_to_ger(node): ...@@ -1615,7 +1615,7 @@ def local_gemm_to_ger(node):
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
try: try:
bval = T.get_scalar_constant_value(b) bval = T.get_scalar_constant_value(b)
except TypeError: except T.NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling # b isn't a constant, GEMM is doing useful pre-scaling
return return
......
...@@ -1603,6 +1603,13 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1603,6 +1603,13 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
self.A, self.a, self.x.dimshuffle(0, 'x'), self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(1.5)).owner) self.y.dimshuffle('x', 0), self.b(1.5)).owner)
def test_b_nonconst_does_not_triggers_ger(self):
""" test local_gemm_to_ger opt"""
assert not T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(
self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.a).owner)
def test_outer(self): def test_outer(self):
f = self.function([self.x, self.y], T.outer(self.x, self.y)) f = self.function([self.x, self.y], T.outer(self.x, self.y))
self.assertFunctionContains(f, self.ger_destructive) self.assertFunctionContains(f, self.ger_destructive)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论