提交 65400e78 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

replace hardcoded T.dot equalities with isinstance checks

上级 67f13b01
......@@ -1531,11 +1531,11 @@ class Dot22(GemmRelated):
_dot22 = Dot22()
@local_optimizer([T.dot])
@local_optimizer([T._dot])
def local_dot_to_dot22(node):
# This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below
if node.op != T.dot:
if not isinstance(node.op, T.Dot):
return
x, y = node.inputs
......
......@@ -410,7 +410,8 @@ def local_lift_transpose_through_dot(node):
if not (isinstance(node.op, T.DimShuffle)
and node.op.new_order == (1, 0)):
return False
if not (node.inputs[0].owner and node.inputs[0].owner.op == T.dot):
if not (node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, T.Dot)):
return False
x, y = node.inputs[0].owner.inputs
......
......@@ -14,7 +14,6 @@ from numpy import (arange, array, common_type, complex64, complex128, float32,
from numpy.testing import assert_array_almost_equal
#from numpy.testing import dec
#from numpy.testing.noseclasses import KnownFailureTest
from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
_is_real_matrix, _gemm_canonicalize,
_factor_canonicalized, Gemm, Gemv,
......@@ -479,7 +478,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
on_unused_input='ignore')
nb_gemm = 0
for node in f.maker.fgraph.apply_nodes:
if node.op == T.dot:
if isinstance(node.op, T.Dot):
raise Failure('dot not changed to gemm_inplace in graph')
if node.op == _dot22:
raise Failure('_dot22 not changed to gemm_inplace in graph')
......@@ -562,7 +561,7 @@ def test_gemm_opt_double_gemm():
f = inplace_func([Param(ii, mutable=True) for ii in i], o,
mode='FAST_RUN', on_unused_input='ignore')
for node in f.maker.fgraph.apply_nodes:
if node.op == T.dot:
if isinstance(node.op, T.Dot):
raise Failure('dot in graph')
if node.op == _dot22:
raise Failure('_dot22 in graph')
......@@ -857,7 +856,8 @@ def test_dot22():
if dtype1 == dtype2:
assert _dot22 in [x.op for x in topo], (dtype1, dtype2)
else:
assert T.dot in [x.op for x in topo], (dtype1, dtype2)
check = [isinstance(x.op, T.Dot) for x in topo]
assert any(check), (dtype1, dtype2)
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
def cmp(a_shp, b_shp):
......@@ -919,8 +919,8 @@ def test_dot22scalar():
assert _dot22 in ops, (dtype1, dtype2,
dtype3, dtype4)
else:
assert T.dot in ops, (dtype1, dtype2,
dtype3, dtype4)
check = [isinstance(o, T.Dot) for o in ops]
assert any(check), (dtype1, dtype2, dtype3, dtype4)
def cmp(a_shp, b_shp, c_shp, sqr_shp=(5, 5)):
av = rng.uniform(size=a_shp).astype(dtype1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论