提交 7a6d676f authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5688 from cooijmanstim/gpuarray-stack-trace

gpuarray: keep stack trace
...@@ -4,6 +4,7 @@ Node classes (`Apply`, `Variable`) and expression graph algorithms. ...@@ -4,6 +4,7 @@ Node classes (`Apply`, `Variable`) and expression graph algorithms.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from collections import deque from collections import deque
import contextlib
from copy import copy from copy import copy
from itertools import count from itertools import count
...@@ -390,6 +391,8 @@ class Variable(Node): ...@@ -390,6 +391,8 @@ class Variable(Node):
self.name = name self.name = name
self.auto_name = 'auto_' + str(next(self.__count__)) self.auto_name = 'auto_' + str(next(self.__count__))
Variable.notify_construction_observers(self)
def __str__(self): def __str__(self):
"""Return a str representation of the Variable. """Return a str representation of the Variable.
...@@ -536,6 +539,22 @@ class Variable(Node): ...@@ -536,6 +539,22 @@ class Variable(Node):
d["tag"] = t d["tag"] = t
return d return d
# refer to doc in nodes_constructed.
construction_observers = []
@classmethod
def append_construction_observer(cls, observer):
cls.construction_observers.append(observer)
@classmethod
def remove_construction_observer(cls, observer):
cls.construction_observers.remove(observer)
@classmethod
def notify_construction_observers(cls, instance):
for observer in cls.construction_observers:
observer(instance)
class Constant(Variable): class Constant(Variable):
""" """
...@@ -1426,3 +1445,38 @@ def is_in_ancestors(l_node, f_node): ...@@ -1426,3 +1445,38 @@ def is_in_ancestors(l_node, f_node):
todo.append(cur) todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner) todo.extend(i.owner for i in cur.inputs if i.owner)
return False return False
@contextlib.contextmanager
def nodes_constructed():
"""
A contextmanager that is used in inherit_stack_trace and keeps track
of all the newly created varaible nodes inside an optimization. A list
of new_nodes is instantiated but will be filled in a lazy manner (when
Variable.notify_construction_observers is called).
`observer` is the entity that updates the new_nodes list.
construction_observers is a list inside Variable class and contains
a list of observer functions. The observer functions inside
construction_observers are only called when a variable node is
instantiated (where Variable.notify_construction_observers is called).
When the observer function is called, a new variable node is added to
the new_nodes list.
Parameters
----------
new_nodes
A list of all the variable nodes that are created inside the optimization.
yields
new_nodes list.
"""
new_nodes = []
def observer(node):
new_nodes.append(node)
Variable.append_construction_observer(observer)
yield new_nodes
Variable.remove_construction_observer(observer)
...@@ -6,6 +6,7 @@ amount of useful generic optimization tools. ...@@ -6,6 +6,7 @@ amount of useful generic optimization tools.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from collections import deque, defaultdict, OrderedDict from collections import deque, defaultdict, OrderedDict
import contextlib
import copy import copy
import inspect import inspect
import logging import logging
...@@ -2902,7 +2903,7 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -2902,7 +2903,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
def copy_stack_trace(from_var, to_var): def copy_stack_trace(from_var, to_var):
""" """
Copies the stack trace from one or more tensor variables to Copies the stack trace from one or more tensor variables to
one or more tensor variables. one or more tensor variables and returns the destination variables.
Parameters Parameters
---------- ----------
...@@ -2946,6 +2947,25 @@ def copy_stack_trace(from_var, to_var): ...@@ -2946,6 +2947,25 @@ def copy_stack_trace(from_var, to_var):
# Copy over stack traces from from_var to each variable to # Copy over stack traces from from_var to each variable to
# to_var, including the stack_trace of the to_var before # to_var, including the stack_trace of the to_var before
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
return to_var
@contextlib.contextmanager
def inherit_stack_trace(from_var):
"""
Contextmanager that copies the stack trace from one or more variable nodes to all
variable nodes constructed in the body. new_nodes is the list of all the newly created
variable nodes inside an optimization that is managed by graph.nodes_constructed().
Parameters
----------
from_var
Variable node or a list of variable nodes to copy stack traces from.
"""
with graph.nodes_constructed() as new_nodes:
yield
copy_stack_trace(from_var, new_nodes)
def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
......
...@@ -15,6 +15,7 @@ from theano.tensor.basic import ( ...@@ -15,6 +15,7 @@ from theano.tensor.basic import (
from theano.gof import HideC, COp, ParamsType from theano.gof import HideC, COp, ParamsType
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gof.opt import copy_stack_trace
from collections import deque from collections import deque
...@@ -75,11 +76,11 @@ def as_gpuarray_variable(x, context_name): ...@@ -75,11 +76,11 @@ def as_gpuarray_variable(x, context_name):
# If we couldn't deal with transfers, then maybe it's a tensor # If we couldn't deal with transfers, then maybe it's a tensor
if isinstance(x.type, tensor.TensorType): if isinstance(x.type, tensor.TensorType):
return GpuFromHost(context_name)(x) return copy_stack_trace(x, GpuFromHost(context_name)(x))
# Try _as_GpuArrayVariable if possible # Try _as_GpuArrayVariable if possible
if hasattr(x, '_as_GpuArrayVariable'): if hasattr(x, '_as_GpuArrayVariable'):
return x._as_GpuArrayVariable(context_name) return copy_stack_trace(x, x._as_GpuArrayVariable(context_name))
# If it didn't work try for a constant # If it didn't work try for a constant
ctx = get_context(context_name) ctx = get_context(context_name)
......
...@@ -18,6 +18,7 @@ from theano.gradient import DisconnectedType, grad_not_implemented ...@@ -18,6 +18,7 @@ from theano.gradient import DisconnectedType, grad_not_implemented
from theano.gof import Optimizer, local_optimizer, COp, ParamsType, EnumList from theano.gof import Optimizer, local_optimizer, COp, ParamsType, EnumList
from theano.gof.cmodule import GCC_compiler from theano.gof.cmodule import GCC_compiler
from theano.gof.type import CDataType, Generic from theano.gof.type import CDataType, Generic
from theano.gof.opt import inherit_stack_trace
from theano.compile import optdb from theano.compile import optdb
from theano.compile.ops import shape_i, shape_i_op from theano.compile.ops import shape_i, shape_i_op
from theano.tensor.nnet import LogSoftmax, SoftmaxGrad from theano.tensor.nnet import LogSoftmax, SoftmaxGrad
...@@ -3127,9 +3128,11 @@ def local_abstractconv_cudnn(node): ...@@ -3127,9 +3128,11 @@ def local_abstractconv_cudnn(node):
if node.op.unshared: if node.op.unshared:
return None return None
if isinstance(node.op, AbstractConv2d): if isinstance(node.op, AbstractConv2d):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs) with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
elif isinstance(node.op, AbstractConv3d): elif isinstance(node.op, AbstractConv3d):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs) with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
@local_optimizer([AbstractConv2d, AbstractConv2d_gradWeights, AbstractConv2d_gradInputs]) @local_optimizer([AbstractConv2d, AbstractConv2d_gradWeights, AbstractConv2d_gradInputs])
...@@ -3352,9 +3355,11 @@ def local_abstractconv_gw_cudnn(node): ...@@ -3352,9 +3355,11 @@ def local_abstractconv_gw_cudnn(node):
if node.op.unshared: if node.op.unshared:
return None return None
if isinstance(node.op, AbstractConv2d_gradWeights): if isinstance(node.op, AbstractConv2d_gradWeights):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs) with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
elif isinstance(node.op, AbstractConv3d_gradWeights): elif isinstance(node.op, AbstractConv3d_gradWeights):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs) with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
@local_optimizer([AbstractConv2d_gradInputs, AbstractConv3d_gradInputs]) @local_optimizer([AbstractConv2d_gradInputs, AbstractConv3d_gradInputs])
...@@ -3365,9 +3370,11 @@ def local_abstractconv_gi_cudnn(node): ...@@ -3365,9 +3370,11 @@ def local_abstractconv_gi_cudnn(node):
if node.op.unshared: if node.op.unshared:
return None return None
if isinstance(node.op, AbstractConv2d_gradInputs): if isinstance(node.op, AbstractConv2d_gradInputs):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs) with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
elif isinstance(node.op, AbstractConv3d_gradInputs): elif isinstance(node.op, AbstractConv3d_gradInputs):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs) with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
@inplace_allocempty(GpuDnnConv, 2) @inplace_allocempty(GpuDnnConv, 2)
...@@ -3384,7 +3391,6 @@ def local_dnn_convgw_inplace(node, inputs): ...@@ -3384,7 +3391,6 @@ def local_dnn_convgw_inplace(node, inputs):
def local_dnn_convgi_inplace(node, inputs): def local_dnn_convgi_inplace(node, inputs):
return [GpuDnnConvGradI(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)] return [GpuDnnConvGradI(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
optdb.register('local_dnna_conv_inplace', optdb.register('local_dnna_conv_inplace',
tensor.opt.in2out(local_dnn_conv_inplace, tensor.opt.in2out(local_dnn_conv_inplace,
local_dnn_convgw_inplace, local_dnn_convgw_inplace,
...@@ -3654,11 +3660,12 @@ def local_dnn_reduction(node): ...@@ -3654,11 +3660,12 @@ def local_dnn_reduction(node):
if not cudnn.cudnnReduceTensorOp_t.has_alias(node.op.scalar_op.name): if not cudnn.cudnnReduceTensorOp_t.has_alias(node.op.scalar_op.name):
return return
return (GpuDnnReduction(node.op.scalar_op.name, with inherit_stack_trace(node.outputs):
node.op.axis, return (GpuDnnReduction(node.op.scalar_op.name,
node.op.acc_dtype, node.op.axis,
node.op.dtype, node.op.acc_dtype,
False)(node.inputs[0]),) node.op.dtype,
False)(node.inputs[0]),)
@register_opt('cudnn') @register_opt('cudnn')
......
差异被折叠。
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from theano import tensor, scalar as scal, Constant from theano import tensor, scalar as scal, Constant
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.gof.opt import inherit_stack_trace
from theano.tensor import (DimShuffle, get_scalar_constant_value, from theano.tensor import (DimShuffle, get_scalar_constant_value,
NotScalarConstantError) NotScalarConstantError)
...@@ -184,7 +185,8 @@ def alpha_merge(cls, alpha_in, beta_in): ...@@ -184,7 +185,8 @@ def alpha_merge(cls, alpha_in, beta_in):
except NotScalarConstantError: except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in] inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in] inputs[beta_in] = lr * targ.inputs[beta_in]
return maker(targ, *inputs) with inherit_stack_trace(node.outputs):
return maker(targ, *inputs)
return opt return opt
return wrapper return wrapper
...@@ -272,7 +274,8 @@ def output_merge(cls, alpha_in, beta_in, out_in): ...@@ -272,7 +274,8 @@ def output_merge(cls, alpha_in, beta_in, out_in):
inputs = list(targ.inputs) inputs = list(targ.inputs)
inputs[out_in] = W inputs[out_in] = W
inputs[beta_in] = _one.clone() inputs[beta_in] = _one.clone()
return maker(targ, *inputs) with inherit_stack_trace(node.outputs):
return maker(targ, *inputs)
return opt return opt
return wrapper return wrapper
...@@ -326,7 +329,8 @@ def inplace_allocempty(op, idx): ...@@ -326,7 +329,8 @@ def inplace_allocempty(op, idx):
len(alloc.clients) > 1): len(alloc.clients) > 1):
alloc_op = GpuAllocEmpty(alloc.owner.op.dtype, alloc.owner.op.context_name) alloc_op = GpuAllocEmpty(alloc.owner.op.dtype, alloc.owner.op.context_name)
inputs[idx] = alloc_op(*alloc.owner.inputs) inputs[idx] = alloc_op(*alloc.owner.inputs)
return maker(node, inputs) with inherit_stack_trace(node.outputs):
return maker(node, inputs)
return opt return opt
return wrapper return wrapper
......
...@@ -146,6 +146,7 @@ from theano.gof import (utils, Op, view_roots, ...@@ -146,6 +146,7 @@ from theano.gof import (utils, Op, view_roots,
EquilibriumOptimizer, Apply, EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError) ReplacementDidntRemovedError)
from theano.gof.params_type import ParamsType from theano.gof.params_type import ParamsType
from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
import theano.scalar import theano.scalar
...@@ -1625,19 +1626,16 @@ def local_dot_to_dot22(node): ...@@ -1625,19 +1626,16 @@ def local_dot_to_dot22(node):
return return
if y.type.dtype in ['float16', 'float32', 'float64', 'complex64', 'complex128']: if y.type.dtype in ['float16', 'float32', 'float64', 'complex64', 'complex128']:
if x.ndim == 2 and y.ndim == 2: with inherit_stack_trace(node.outputs):
# print "local_dot_to_dot22: MM" if x.ndim == 2 and y.ndim == 2:
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
if x.ndim == 2 and y.ndim == 1: if x.ndim == 2 and y.ndim == 1:
# print "local_dot_to_dot22: MV" return [_dot22(x, y.dimshuffle(0, 'x')).dimshuffle(0)]
return [_dot22(x, y.dimshuffle(0, 'x')).dimshuffle(0)] if x.ndim == 1 and y.ndim == 2:
if x.ndim == 1 and y.ndim == 2: return [_dot22(x.dimshuffle('x', 0), y).dimshuffle(1)]
# print "local_dot_to_dot22: VM" if x.ndim == 1 and y.ndim == 1:
return [_dot22(x.dimshuffle('x', 0), y).dimshuffle(1)] return [_dot22(x.dimshuffle('x', 0),
if x.ndim == 1 and y.ndim == 1: y.dimshuffle(0, 'x')).dimshuffle()]
# print "local_dot_to_dot22: VV"
return [_dot22(x.dimshuffle('x', 0),
y.dimshuffle(0, 'x')).dimshuffle()]
_logger.info('Not optimizing dot with inputs %s %s %s %s', _logger.info('Not optimizing dot with inputs %s %s %s %s',
x, y, x.type, y.type) x, y, x.type, y.type)
...@@ -1646,19 +1644,22 @@ def local_dot_to_dot22(node): ...@@ -1646,19 +1644,22 @@ def local_dot_to_dot22(node):
@local_optimizer([gemm_no_inplace], inplace=True) @local_optimizer([gemm_no_inplace], inplace=True)
def local_inplace_gemm(node): def local_inplace_gemm(node):
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
return [gemm_inplace(*node.inputs)] with inherit_stack_trace(node.outputs):
return [gemm_inplace(*node.inputs)]
@local_optimizer([gemv_no_inplace], inplace=True) @local_optimizer([gemv_no_inplace], inplace=True)
def local_inplace_gemv(node): def local_inplace_gemv(node):
if node.op == gemv_no_inplace: if node.op == gemv_no_inplace:
return [gemv_inplace(*node.inputs)] with inherit_stack_trace(node.outputs):
return [gemv_inplace(*node.inputs)]
@local_optimizer([ger], inplace=True) @local_optimizer([ger], inplace=True)
def local_inplace_ger(node): def local_inplace_ger(node):
if node.op == ger: if node.op == ger:
return [ger_destructive(*node.inputs)] with inherit_stack_trace(node.outputs):
return [ger_destructive(*node.inputs)]
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
...@@ -1666,12 +1667,13 @@ def local_gemm_to_gemv(node): ...@@ -1666,12 +1667,13 @@ def local_gemm_to_gemv(node):
"""GEMM acting on row or column matrices -> GEMV.""" """GEMM acting on row or column matrices -> GEMV."""
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs z, a, x, y, b = node.inputs
if z.broadcastable == x.broadcastable == (True, False): with inherit_stack_trace(node.outputs):
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) if z.broadcastable == x.broadcastable == (True, False):
return [r.dimshuffle('x', 0)] r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
if z.broadcastable == y.broadcastable == (False, True): return [r.dimshuffle('x', 0)]
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) if z.broadcastable == y.broadcastable == (False, True):
return [r.dimshuffle(0, 'x')] r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
return [r.dimshuffle(0, 'x')]
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
...@@ -1680,26 +1682,27 @@ def local_gemm_to_ger(node): ...@@ -1680,26 +1682,27 @@ def local_gemm_to_ger(node):
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs z, a, x, y, b = node.inputs
if x.broadcastable[1] and y.broadcastable[0]: if x.broadcastable[1] and y.broadcastable[0]:
# x and y are both vectors so this might qualifies for a GER with inherit_stack_trace(node.outputs):
xv = x.dimshuffle(0) # x and y are both vectors so this might qualifies for a GER
yv = y.dimshuffle(1) xv = x.dimshuffle(0)
try: yv = y.dimshuffle(1)
bval = T.get_scalar_constant_value(b) try:
except T.NotScalarConstantError: bval = T.get_scalar_constant_value(b)
# b isn't a constant, GEMM is doing useful pre-scaling except T.NotScalarConstantError:
return # b isn't a constant, GEMM is doing useful pre-scaling
return
if bval == 1: # best case a natural GER
rval = ger(z, a, xv, yv) if bval == 1: # best case a natural GER
return [rval] rval = ger(z, a, xv, yv)
elif bval == 0: # GER on zeros_like should be faster than GEMM return [rval]
zeros = T.zeros([x.shape[0], y.shape[1]], x.dtype) elif bval == 0: # GER on zeros_like should be faster than GEMM
rval = ger(zeros, a, xv, yv) zeros = T.zeros([x.shape[0], y.shape[1]], x.dtype)
return [rval] rval = ger(zeros, a, xv, yv)
else: return [rval]
# if bval is another constant, then z is being usefully else:
# pre-scaled and GER isn't really the right tool for the job. # if bval is another constant, then z is being usefully
return # pre-scaled and GER isn't really the right tool for the job.
return
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
...@@ -1708,37 +1711,38 @@ def local_gemm_to_ger(node): ...@@ -1708,37 +1711,38 @@ def local_gemm_to_ger(node):
def local_dot22_to_ger_or_gemv(node): def local_dot22_to_ger_or_gemv(node):
"""dot22 computing an outer-product -> GER.""" """dot22 computing an outer-product -> GER."""
if node.op == _dot22: if node.op == _dot22:
x, y = node.inputs with inherit_stack_trace(node.outputs):
xb = x.broadcastable x, y = node.inputs
yb = y.broadcastable xb = x.broadcastable
one = T.as_tensor_variable(np.asarray(1, dtype=x.dtype)) yb = y.broadcastable
zero = T.as_tensor_variable(np.asarray(0, dtype=x.dtype)) one = T.as_tensor_variable(np.asarray(1, dtype=x.dtype))
if xb[1] and yb[0]: zero = T.as_tensor_variable(np.asarray(0, dtype=x.dtype))
# x and y are both vectors so this might qualifies for a GER if xb[1] and yb[0]:
xv = x.dimshuffle(0) # x and y are both vectors so this might qualifies for a GER
yv = y.dimshuffle(1) xv = x.dimshuffle(0)
zeros = T.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) yv = y.dimshuffle(1)
rval = ger(zeros, one, xv, yv) zeros = T.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
return [rval] rval = ger(zeros, one, xv, yv)
if xb[0] and yb[1]: return [rval]
# x and y are both vectors so this qualifies for a sdot / ddot if xb[0] and yb[1]:
# TODO: Theano doesn't have a sdot, but gemv is better than _dot22 # x and y are both vectors so this qualifies for a sdot / ddot
xv = x.dimshuffle(1) # TODO: Theano doesn't have a sdot, but gemv is better than _dot22
zeros = T.AllocEmpty(x.dtype)(1) xv = x.dimshuffle(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) zeros = T.AllocEmpty(x.dtype)(1)
return [rval.dimshuffle('x', 0)] rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
if xb[0] and not yb[0] and not yb[1]: return [rval.dimshuffle('x', 0)]
# x is vector, y is matrix so try gemv if xb[0] and not yb[0] and not yb[1]:
xv = x.dimshuffle(1) # x is vector, y is matrix so try gemv
zeros = T.AllocEmpty(x.dtype)(y.shape[1]) xv = x.dimshuffle(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) zeros = T.AllocEmpty(x.dtype)(y.shape[1])
return [rval.dimshuffle('x', 0)] rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
if not xb[0] and not xb[1] and yb[1]: return [rval.dimshuffle('x', 0)]
# x is matrix, y is vector, try gemv if not xb[0] and not xb[1] and yb[1]:
yv = y.dimshuffle(0) # x is matrix, y is vector, try gemv
zeros = T.AllocEmpty(x.dtype)(x.shape[0]) yv = y.dimshuffle(0)
rval = gemv_no_inplace(zeros, one, x, yv, zero) zeros = T.AllocEmpty(x.dtype)(x.shape[0])
return [rval.dimshuffle(0, 'x')] rval = gemv_no_inplace(zeros, one, x, yv, zero)
return [rval.dimshuffle(0, 'x')]
################################# #################################
......
...@@ -43,6 +43,7 @@ from theano.tensor import DimShuffle, Subtensor ...@@ -43,6 +43,7 @@ from theano.tensor import DimShuffle, Subtensor
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal from theano import scalar as scal
from theano.gof.opt import copy_stack_trace
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
...@@ -57,10 +58,13 @@ def local_max_and_argmax(node): ...@@ -57,10 +58,13 @@ def local_max_and_argmax(node):
axis = node.op.get_params(node) axis = node.op.get_params(node)
if len(node.outputs[1].clients) == 0: if len(node.outputs[1].clients) == 0:
new = CAReduce(scal.maximum, axis)(node.inputs[0]) new = CAReduce(scal.maximum, axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [new, None] return [new, None]
if len(node.outputs[0].clients) == 0: if len(node.outputs[0].clients) == 0:
return [None, T.Argmax(axis)(node.inputs[0])] new = T.Argmax(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [None, new]
@register_uncanonicalize @register_uncanonicalize
...@@ -84,8 +88,8 @@ def local_max_to_min(node): ...@@ -84,8 +88,8 @@ def local_max_to_min(node):
max.owner.op.scalar_op == scal.maximum): max.owner.op.scalar_op == scal.maximum):
neg = max.owner.inputs[0] neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg: if neg.owner and neg.owner.op == T.neg:
return [CAReduce(scal.minimum, new = CAReduce(scal.minimum, max.owner.op.axis)(neg.owner.inputs[0])
max.owner.op.axis)(neg.owner.inputs[0])] return [copy_stack_trace(node.outputs[0], new)]
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论