提交 d1be796e 作者: ricardoV94 提交者: Ricardo Vieira

Remove useless checks guaranteed by tracks

上级 4c40efa2
...@@ -636,93 +636,89 @@ def local_inplace_ger(fgraph, node): ...@@ -636,93 +636,89 @@ def local_inplace_ger(fgraph, node):
@node_rewriter([gemm_no_inplace]) @node_rewriter([gemm_no_inplace])
def local_gemm_to_gemv(fgraph, node): def local_gemm_to_gemv(fgraph, node):
"""GEMM acting on row or column matrices -> GEMV.""" """GEMM acting on row or column matrices -> GEMV."""
if node.op == gemm_no_inplace: z, a, x, y, b = node.inputs
z, a, x, y, b = node.inputs if z.broadcastable == x.broadcastable == (True, False):
if z.broadcastable == x.broadcastable == (True, False): r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) new_out = [r.dimshuffle("x", 0)]
new_out = [r.dimshuffle("x", 0)] elif z.broadcastable == y.broadcastable == (False, True):
elif z.broadcastable == y.broadcastable == (False, True): r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) new_out = [r.dimshuffle(0, "x")]
new_out = [r.dimshuffle(0, "x")] else:
else: return
return copy_stack_trace(node.outputs, new_out)
copy_stack_trace(node.outputs, new_out) return new_out
return new_out
@node_rewriter([gemm_no_inplace]) @node_rewriter([gemm_no_inplace])
def local_gemm_to_ger(fgraph, node): def local_gemm_to_ger(fgraph, node):
"""GEMM computing an outer-product -> GER.""" """GEMM computing an outer-product -> GER."""
if node.op == gemm_no_inplace: z, a, x, y, b = node.inputs
z, a, x, y, b = node.inputs if x.broadcastable[1] and y.broadcastable[0]:
if x.broadcastable[1] and y.broadcastable[0]: # x and y are both vectors so this might qualifies for a GER
# x and y are both vectors so this might qualifies for a GER xv = x.dimshuffle(0)
xv = x.dimshuffle(0) yv = y.dimshuffle(1)
yv = y.dimshuffle(1) try:
try: bval = ptb.get_underlying_scalar_constant_value(b)
bval = ptb.get_underlying_scalar_constant_value(b) except NotScalarConstantError:
except NotScalarConstantError: # b isn't a constant, GEMM is doing useful pre-scaling
# b isn't a constant, GEMM is doing useful pre-scaling return
return
if bval == 1: # best case a natural GER
rval = ger(z, a, xv, yv)
new_out = [rval]
elif bval == 0: # GER on zeros_like should be faster than GEMM
zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype)
rval = ger(zeros, a, xv, yv)
new_out = [rval]
else:
# if bval is another constant, then z is being usefully
# pre-scaled and GER isn't really the right tool for the job.
return
copy_stack_trace(node.outputs, new_out)
return new_out
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline if bval == 1: # best case a natural GER
# working rval = ger(z, a, xv, yv)
@node_rewriter([_dot22]) new_out = [rval]
def local_dot22_to_ger_or_gemv(fgraph, node): elif bval == 0: # GER on zeros_like should be faster than GEMM
"""dot22 computing an outer-product -> GER.""" zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype)
if node.op == _dot22: rval = ger(zeros, a, xv, yv)
x, y = node.inputs
xb = x.broadcastable
yb = y.broadcastable
one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype))
zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
rval = ger(zeros, one, xv, yv)
new_out = [rval] new_out = [rval]
elif xb[0] and yb[1]:
# x and y are both vectors so this qualifies for a sdot / ddot
# PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
xv = x.dimshuffle(1)
zeros = ptb.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
new_out = [rval.dimshuffle("x", 0)]
elif xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv
xv = x.dimshuffle(1)
zeros = ptb.AllocEmpty(x.dtype)(y.shape[1])
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
new_out = [rval.dimshuffle("x", 0)]
elif not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv
yv = y.dimshuffle(0)
zeros = ptb.AllocEmpty(x.dtype)(x.shape[0])
rval = gemv_no_inplace(zeros, one, x, yv, zero)
new_out = [rval.dimshuffle(0, "x")]
else: else:
# if bval is another constant, then z is being usefully
# pre-scaled and GER isn't really the right tool for the job.
return return
copy_stack_trace(node.outputs, new_out) copy_stack_trace(node.outputs, new_out)
return new_out return new_out
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline working
@node_rewriter([_dot22])
def local_dot22_to_ger_or_gemv(fgraph, node):
    """Replace a ``_dot22`` that has a vector-like operand with GER or GEMV.

    The two inputs of ``node`` (called ``x`` and ``y`` below) are inspected
    through their ``broadcastable`` flags, and the first matching case wins:

    1. ``x`` broadcastable on axis 1 and ``y`` on axis 0: ``ger`` accumulated
       into a zero matrix of shape ``(x.shape[0], y.shape[1])``.
    2. ``x`` broadcastable on axis 0 and ``y`` on axis 1: ``gemv_no_inplace``
       with a length-1 output buffer.
    3. ``x`` broadcastable on axis 0 and ``y`` broadcastable on neither axis:
       ``gemv_no_inplace`` applied to ``y.T``.
    4. ``x`` broadcastable on neither axis and ``y`` on axis 1:
       ``gemv_no_inplace`` applied to ``x``.

    Parameters
    ----------
    fgraph
        Not used by this rewrite.
    node
        The node being rewritten; its two inputs are unpacked as ``x, y``.

    Returns
    -------
    list or None
        A one-element list with the replacement output, or ``None`` when
        none of the cases above applies.
    """
    lhs, rhs = node.inputs
    lhs_bcast = lhs.broadcastable
    rhs_bcast = rhs.broadcastable
    # Scalar coefficients handed to ger / gemv, typed like the left operand.
    alpha = ptb.as_tensor_variable(np.asarray(1, dtype=lhs.dtype))
    beta = ptb.as_tensor_variable(np.asarray(0, dtype=lhs.dtype))

    if lhs_bcast[1] and rhs_bcast[0]:
        # Column times row: both operands are really vectors, so this can
        # be expressed as a GER on a zero-filled matrix.
        lhs_vec = lhs.dimshuffle(0)
        rhs_vec = rhs.dimshuffle(1)
        accumulator = ptb.zeros([lhs.shape[0], rhs.shape[1]], dtype=lhs.dtype)
        replacement = [ger(accumulator, alpha, lhs_vec, rhs_vec)]
    elif lhs_bcast[0] and rhs_bcast[1]:
        # Row times column: both operands are really vectors, so this is an
        # sdot / ddot. PyTensor's CGemv will call sdot/ddot at runtime, the
        # Scipy Gemv may not.
        lhs_vec = lhs.dimshuffle(1)
        # NOTE(review): beta is 0, so the buffer contents presumably do not
        # matter, which would explain AllocEmpty over zeros; confirm.
        out_buffer = ptb.AllocEmpty(lhs.dtype)(1)
        gemv_out = gemv_no_inplace(out_buffer, alpha, rhs.T, lhs_vec, beta)
        replacement = [gemv_out.dimshuffle("x", 0)]
    elif lhs_bcast[0] and not rhs_bcast[0] and not rhs_bcast[1]:
        # Row vector times full matrix: GEMV on the transposed matrix.
        lhs_vec = lhs.dimshuffle(1)
        out_buffer = ptb.AllocEmpty(lhs.dtype)(rhs.shape[1])
        gemv_out = gemv_no_inplace(out_buffer, alpha, rhs.T, lhs_vec, beta)
        replacement = [gemv_out.dimshuffle("x", 0)]
    elif not lhs_bcast[0] and not lhs_bcast[1] and rhs_bcast[1]:
        # Full matrix times column vector: plain GEMV.
        rhs_vec = rhs.dimshuffle(0)
        out_buffer = ptb.AllocEmpty(lhs.dtype)(lhs.shape[0])
        gemv_out = gemv_no_inplace(out_buffer, alpha, lhs, rhs_vec, beta)
        replacement = [gemv_out.dimshuffle(0, "x")]
    else:
        # No vector-like operand pattern matched; leave the node alone.
        return None

    copy_stack_trace(node.outputs, replacement)
    return replacement
################################# #################################
# #
# Set up the BlasOpt optimizer # Set up the BlasOpt optimizer
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论