提交 1e6bbdef 作者: Brandon T. Willard

Replace theano.tensor alias T with tt in theano.tensor.blas

上级 1ff084bf
...@@ -161,7 +161,7 @@ from theano.gof.opt import inherit_stack_trace ...@@ -161,7 +161,7 @@ from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
from theano.scalar import bool as bool_t from theano.scalar import bool as bool_t
from theano.tensor import basic as T from theano.tensor import basic as tt
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import in2out, local_dimshuffle_lift from theano.tensor.opt import in2out, local_dimshuffle_lift
...@@ -246,11 +246,11 @@ class Gemv(Op): ...@@ -246,11 +246,11 @@ class Gemv(Op):
return "%s{no_inplace}" % self.__class__.__name__ return "%s{no_inplace}" % self.__class__.__name__
def make_node(self, y, alpha, A, x, beta): def make_node(self, y, alpha, A, x, beta):
y = T.as_tensor_variable(y) y = tt.as_tensor_variable(y)
x = T.as_tensor_variable(x) x = tt.as_tensor_variable(x)
A = T.as_tensor_variable(A) A = tt.as_tensor_variable(A)
alpha = T.as_tensor_variable(alpha) alpha = tt.as_tensor_variable(alpha)
beta = T.as_tensor_variable(beta) beta = tt.as_tensor_variable(beta)
if y.dtype != A.dtype or y.dtype != x.dtype: if y.dtype != A.dtype or y.dtype != x.dtype:
raise TypeError( raise TypeError(
"Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype) "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
...@@ -340,10 +340,10 @@ class Ger(Op): ...@@ -340,10 +340,10 @@ class Ger(Op):
return "%s{non-destructive}" % self.__class__.__name__ return "%s{non-destructive}" % self.__class__.__name__
def make_node(self, A, alpha, x, y): def make_node(self, A, alpha, x, y):
A = T.as_tensor_variable(A) A = tt.as_tensor_variable(A)
y = T.as_tensor_variable(y) y = tt.as_tensor_variable(y)
x = T.as_tensor_variable(x) x = tt.as_tensor_variable(x)
alpha = T.as_tensor_variable(alpha) alpha = tt.as_tensor_variable(alpha)
if not (A.dtype == x.dtype == y.dtype == alpha.dtype): if not (A.dtype == x.dtype == y.dtype == alpha.dtype):
raise TypeError( raise TypeError(
"ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype) "ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype)
...@@ -898,7 +898,7 @@ class Gemm(GemmRelated): ...@@ -898,7 +898,7 @@ class Gemm(GemmRelated):
return rval return rval
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(T.as_tensor_variable, inputs)) inputs = list(map(tt.as_tensor_variable, inputs))
if len(inputs) != 5: if len(inputs) != 5:
raise TypeError( raise TypeError(
"Wrong number of inputs for %s (expected 5, got %s)" "Wrong number of inputs for %s (expected 5, got %s)"
...@@ -1117,7 +1117,7 @@ def _as_scalar(res, dtype=None): ...@@ -1117,7 +1117,7 @@ def _as_scalar(res, dtype=None):
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
if np.all(res.type.broadcastable): if np.all(res.type.broadcastable):
while res.owner and isinstance(res.owner.op, T.DimShuffle): while res.owner and isinstance(res.owner.op, tt.DimShuffle):
res = res.owner.inputs[0] res = res.owner.inputs[0]
# may still have some number of True's # may still have some number of True's
if res.type.broadcastable: if res.type.broadcastable:
...@@ -1131,7 +1131,7 @@ def _as_scalar(res, dtype=None): ...@@ -1131,7 +1131,7 @@ def _as_scalar(res, dtype=None):
# as the cast of the scalar can be done before or after the dot22 # as the cast of the scalar can be done before or after the dot22
# and this will give the same result. # and this will give the same result.
if theano.scalar.upcast(res.dtype, dtype) == dtype: if theano.scalar.upcast(res.dtype, dtype) == dtype:
return T.cast(rval, dtype) return tt.cast(rval, dtype)
else: else:
return None return None
...@@ -1171,7 +1171,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1171,7 +1171,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
# and the dot22. local_dot_to_dot22 in particular will put in such things. # and the dot22. local_dot_to_dot22 in particular will put in such things.
if ( if (
M.owner M.owner
and isinstance(M.owner.op, T.DimShuffle) and isinstance(M.owner.op, tt.DimShuffle)
and M.owner.inputs[0].owner and M.owner.inputs[0].owner
and isinstance(M.owner.inputs[0].owner.op, Dot22) and isinstance(M.owner.inputs[0].owner.op, Dot22)
): ):
...@@ -1262,24 +1262,24 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -1262,24 +1262,24 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
rval.append((scale, r)) rval.append((scale, r))
return rval return rval
if r.owner and r.owner.op == T.sub: if r.owner and r.owner.op == tt.sub:
_gemm_canonicalize(r.owner.inputs[0], scale, rval, 1) _gemm_canonicalize(r.owner.inputs[0], scale, rval, 1)
_gemm_canonicalize(r.owner.inputs[1], -scale, rval, 1) _gemm_canonicalize(r.owner.inputs[1], -scale, rval, 1)
elif r.owner and r.owner.op == T.add: elif r.owner and r.owner.op == tt.add:
for i in r.owner.inputs: for i in r.owner.inputs:
_gemm_canonicalize(i, scale, rval, 1) _gemm_canonicalize(i, scale, rval, 1)
elif r.owner and r.owner.op == T.neg: elif r.owner and r.owner.op == tt.neg:
_gemm_canonicalize(r.owner.inputs[0], -scale, rval, 1) _gemm_canonicalize(r.owner.inputs[0], -scale, rval, 1)
elif r.owner and r.owner.op == T.mul: elif r.owner and r.owner.op == tt.mul:
scalars = [] scalars = []
vectors = [] vectors = []
matrices = [] matrices = []
for i in r.owner.inputs: for i in r.owner.inputs:
if np.all(i.type.broadcastable): if np.all(i.type.broadcastable):
while i.owner and isinstance(i.owner.op, T.DimShuffle): while i.owner and isinstance(i.owner.op, tt.DimShuffle):
i = i.owner.inputs[0] i = i.owner.inputs[0]
if i.type.broadcastable: if i.type.broadcastable:
scalars.append(i.dimshuffle()) scalars.append(i.dimshuffle())
...@@ -1301,7 +1301,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -1301,7 +1301,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
elif len(scalars) == 1: elif len(scalars) == 1:
_gemm_canonicalize(m, scaled(scalars[0]), rval, 1) _gemm_canonicalize(m, scaled(scalars[0]), rval, 1)
else: else:
_gemm_canonicalize(m, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1) _gemm_canonicalize(m, tt.mul(scaled(scalars[0]), *scalars[1:]), rval, 1)
elif len(vectors) == 1: elif len(vectors) == 1:
assert len(matrices) == 0 assert len(matrices) == 0
v = vectors[0] v = vectors[0]
...@@ -1310,7 +1310,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -1310,7 +1310,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
elif len(scalars) == 1: elif len(scalars) == 1:
_gemm_canonicalize(v, scaled(scalars[0]), rval, 1) _gemm_canonicalize(v, scaled(scalars[0]), rval, 1)
else: else:
_gemm_canonicalize(v, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1) _gemm_canonicalize(v, tt.mul(scaled(scalars[0]), *scalars[1:]), rval, 1)
else: # lets not open this up else: # lets not open this up
rval.append((scale, r)) rval.append((scale, r))
else: else:
...@@ -1372,9 +1372,9 @@ def _gemm_from_factored_list(lst): ...@@ -1372,9 +1372,9 @@ def _gemm_from_factored_list(lst):
# sM can be a tuple of 2 elements or a theano variable. # sM can be a tuple of 2 elements or a theano variable.
if isinstance(sM, tuple): if isinstance(sM, tuple):
sm0, sm1 = sM sm0, sm1 = sM
sm0 = T.as_tensor_variable(sm0) sm0 = tt.as_tensor_variable(sm0)
if theano.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype: if theano.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
lst2.append((T.cast(sm0, sm1.dtype), sM[1])) lst2.append((tt.cast(sm0, sm1.dtype), sM[1]))
lst = lst2 lst = lst2
...@@ -1411,7 +1411,7 @@ def _gemm_from_factored_list(lst): ...@@ -1411,7 +1411,7 @@ def _gemm_from_factored_list(lst):
] ]
add_inputs.extend(gemm_of_sM_list) add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1: if len(add_inputs) > 1:
rval = [T.add(*add_inputs)] rval = [tt.add(*add_inputs)]
else: else:
rval = add_inputs rval = add_inputs
# print "RETURNING GEMM THING", rval # print "RETURNING GEMM THING", rval
...@@ -1495,7 +1495,7 @@ class GemmOptimizer(Optimizer): ...@@ -1495,7 +1495,7 @@ class GemmOptimizer(Optimizer):
nodelist.reverse() nodelist.reverse()
for node in nodelist: for node in nodelist:
if not ( if not (
isinstance(node.op, T.Elemwise) isinstance(node.op, tt.Elemwise)
and isinstance( and isinstance(
node.op.scalar_op, node.op.scalar_op,
( (
...@@ -1606,8 +1606,8 @@ class Dot22(GemmRelated): ...@@ -1606,8 +1606,8 @@ class Dot22(GemmRelated):
check_input = False check_input = False
def make_node(self, x, y): def make_node(self, x, y):
x = T.as_tensor_variable(x) x = tt.as_tensor_variable(x)
y = T.as_tensor_variable(y) y = tt.as_tensor_variable(y)
dtypes = ("float16", "float32", "float64", "complex64", "complex128") dtypes = ("float16", "float32", "float64", "complex64", "complex128")
if x.type.ndim != 2 or x.type.dtype not in dtypes: if x.type.ndim != 2 or x.type.dtype not in dtypes:
raise TypeError(x) raise TypeError(x)
...@@ -1616,7 +1616,7 @@ class Dot22(GemmRelated): ...@@ -1616,7 +1616,7 @@ class Dot22(GemmRelated):
if y.type.dtype != x.type.dtype: if y.type.dtype != x.type.dtype:
raise TypeError("dtype mismatch to Dot22") raise TypeError("dtype mismatch to Dot22")
bz = (x.type.broadcastable[0], y.type.broadcastable[1]) bz = (x.type.broadcastable[0], y.type.broadcastable[1])
outputs = [T.tensor(x.type.dtype, bz)] outputs = [tt.tensor(x.type.dtype, bz)]
return Apply(self, [x, y], outputs) return Apply(self, [x, y], outputs)
def perform(self, node, inp, out): def perform(self, node, inp, out):
...@@ -1686,11 +1686,11 @@ class Dot22(GemmRelated): ...@@ -1686,11 +1686,11 @@ class Dot22(GemmRelated):
_dot22 = Dot22() _dot22 = Dot22()
@local_optimizer([T.Dot]) @local_optimizer([tt.Dot])
def local_dot_to_dot22(node): def local_dot_to_dot22(node):
# This works for tensor.outer too because basic.outer is a macro that # This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below # produces a dot(dimshuffle,dimshuffle) of form 4 below
if not isinstance(node.op, T.Dot): if not isinstance(node.op, tt.Dot):
return return
x, y = node.inputs x, y = node.inputs
...@@ -1759,8 +1759,8 @@ def local_gemm_to_ger(node): ...@@ -1759,8 +1759,8 @@ def local_gemm_to_ger(node):
xv = x.dimshuffle(0) xv = x.dimshuffle(0)
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
try: try:
bval = T.get_scalar_constant_value(b) bval = tt.get_scalar_constant_value(b)
except T.NotScalarConstantError: except tt.NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling # b isn't a constant, GEMM is doing useful pre-scaling
return return
...@@ -1768,7 +1768,7 @@ def local_gemm_to_ger(node): ...@@ -1768,7 +1768,7 @@ def local_gemm_to_ger(node):
rval = ger(z, a, xv, yv) rval = ger(z, a, xv, yv)
return [rval] return [rval]
elif bval == 0: # GER on zeros_like should be faster than GEMM elif bval == 0: # GER on zeros_like should be faster than GEMM
zeros = T.zeros([x.shape[0], y.shape[1]], x.dtype) zeros = tt.zeros([x.shape[0], y.shape[1]], x.dtype)
rval = ger(zeros, a, xv, yv) rval = ger(zeros, a, xv, yv)
return [rval] return [rval]
else: else:
...@@ -1787,32 +1787,32 @@ def local_dot22_to_ger_or_gemv(node): ...@@ -1787,32 +1787,32 @@ def local_dot22_to_ger_or_gemv(node):
x, y = node.inputs x, y = node.inputs
xb = x.broadcastable xb = x.broadcastable
yb = y.broadcastable yb = y.broadcastable
one = T.as_tensor_variable(np.asarray(1, dtype=x.dtype)) one = tt.as_tensor_variable(np.asarray(1, dtype=x.dtype))
zero = T.as_tensor_variable(np.asarray(0, dtype=x.dtype)) zero = tt.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]: if xb[1] and yb[0]:
# x and y are both vectors so this might qualify for a GER # x and y are both vectors so this might qualify for a GER
xv = x.dimshuffle(0) xv = x.dimshuffle(0)
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
zeros = T.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) zeros = tt.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
rval = ger(zeros, one, xv, yv) rval = ger(zeros, one, xv, yv)
return [rval] return [rval]
if xb[0] and yb[1]: if xb[0] and yb[1]:
# x and y are both vectors so this qualifies for a sdot / ddot # x and y are both vectors so this qualifies for a sdot / ddot
# TODO: Theano doesn't have a sdot, but gemv is better than _dot22 # TODO: Theano doesn't have a sdot, but gemv is better than _dot22
xv = x.dimshuffle(1) xv = x.dimshuffle(1)
zeros = T.AllocEmpty(x.dtype)(1) zeros = tt.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle("x", 0)] return [rval.dimshuffle("x", 0)]
if xb[0] and not yb[0] and not yb[1]: if xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv # x is vector, y is matrix so try gemv
xv = x.dimshuffle(1) xv = x.dimshuffle(1)
zeros = T.AllocEmpty(x.dtype)(y.shape[1]) zeros = tt.AllocEmpty(x.dtype)(y.shape[1])
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle("x", 0)] return [rval.dimshuffle("x", 0)]
if not xb[0] and not xb[1] and yb[1]: if not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv # x is matrix, y is vector, try gemv
yv = y.dimshuffle(0) yv = y.dimshuffle(0)
zeros = T.AllocEmpty(x.dtype)(x.shape[0]) zeros = tt.AllocEmpty(x.dtype)(x.shape[0])
rval = gemv_no_inplace(zeros, one, x, yv, zero) rval = gemv_no_inplace(zeros, one, x, yv, zero)
return [rval.dimshuffle(0, "x")] return [rval.dimshuffle(0, "x")]
...@@ -1891,7 +1891,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1891,7 +1891,7 @@ class Dot22Scalar(GemmRelated):
raise TypeError("Dot22Scalar requires float or complex args", a.dtype) raise TypeError("Dot22Scalar requires float or complex args", a.dtype)
bz = [x.type.broadcastable[0], y.type.broadcastable[1]] bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)] outputs = [tt.tensor(x.type.dtype, bz)]
return Apply(self, [x, y, a], outputs) return Apply(self, [x, y, a], outputs)
def perform(self, node, inp, out): def perform(self, node, inp, out):
...@@ -1956,7 +1956,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1956,7 +1956,7 @@ class Dot22Scalar(GemmRelated):
_dot22scalar = Dot22Scalar() _dot22scalar = Dot22Scalar()
@local_optimizer([T.mul]) @local_optimizer([tt.mul])
def local_dot22_to_dot22scalar(node): def local_dot22_to_dot22scalar(node):
""" """
Notes Notes
...@@ -1981,7 +1981,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -1981,7 +1981,7 @@ def local_dot22_to_dot22scalar(node):
inputs) inputs)
""" """
if node.op != T.mul: if node.op != tt.mul:
return False return False
i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs] i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs]
if not any(i_dot22): if not any(i_dot22):
...@@ -1999,7 +1999,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -1999,7 +1999,7 @@ def local_dot22_to_dot22scalar(node):
# The canonizer should have merged those mul together. # The canonizer should have merged those mul together.
i_mul = [ i_mul = [
x.owner x.owner
and x.owner.op == T.mul and x.owner.op == tt.mul
and any([_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs]) and any([_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs])
for x in node.inputs for x in node.inputs
] ]
...@@ -2029,7 +2029,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -2029,7 +2029,7 @@ def local_dot22_to_dot22scalar(node):
[x.type for x in node.inputs], [x.type for x in node.inputs],
) )
return False return False
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype) a = tt.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype)
assert not a.type.ndim assert not a.type.ndim
dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a) dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
...@@ -2044,7 +2044,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -2044,7 +2044,7 @@ def local_dot22_to_dot22scalar(node):
inpt for i, inpt in enumerate(m.owner.inputs) if i != scalar_idx inpt for i, inpt in enumerate(m.owner.inputs) if i != scalar_idx
] ]
return [T.mul(dot, *(other_factors + other_m_inputs))] return [tt.mul(dot, *(other_factors + other_m_inputs))]
scalar_idx = -1 scalar_idx = -1
for i, x in enumerate(node.inputs): for i, x in enumerate(node.inputs):
...@@ -2069,12 +2069,12 @@ def local_dot22_to_dot22scalar(node): ...@@ -2069,12 +2069,12 @@ def local_dot22_to_dot22scalar(node):
o.remove(d) o.remove(d)
o.remove(s) o.remove(s)
a = T.cast(i_scalar[scalar_idx], d.type.dtype) a = tt.cast(i_scalar[scalar_idx], d.type.dtype)
assert not a.type.ndim assert not a.type.ndim
if len(o) == 0: if len(o) == 0:
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)] return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
else: else:
return [T.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)] return [tt.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)]
# must happen after gemm as the gemm optimizer doesn't understand # must happen after gemm as the gemm optimizer doesn't understand
...@@ -2094,7 +2094,7 @@ class BatchedDot(Op): ...@@ -2094,7 +2094,7 @@ class BatchedDot(Op):
__props__ = () __props__ = ()
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(T.as_tensor_variable, inputs)) inputs = list(map(tt.as_tensor_variable, inputs))
if len(inputs) != 2: if len(inputs) != 2:
raise TypeError( raise TypeError(
...@@ -2116,13 +2116,13 @@ class BatchedDot(Op): ...@@ -2116,13 +2116,13 @@ class BatchedDot(Op):
dtype = theano.scalar.upcast(*[input.type.dtype for input in inputs]) dtype = theano.scalar.upcast(*[input.type.dtype for input in inputs])
# upcast inputs to common dtype if needed # upcast inputs to common dtype if needed
upcasted_inputs = [T.cast(input, dtype) for input in inputs] upcasted_inputs = [tt.cast(input, dtype) for input in inputs]
broadcastable = ( broadcastable = (
(inputs[0].type.broadcastable[0] or inputs[1].type.broadcastable[0],) (inputs[0].type.broadcastable[0] or inputs[1].type.broadcastable[0],)
+ inputs[0].type.broadcastable[1:-1] + inputs[0].type.broadcastable[1:-1]
+ inputs[1].type.broadcastable[2:] + inputs[1].type.broadcastable[2:]
) )
return Apply(self, upcasted_inputs, [T.tensor(dtype, broadcastable)]) return Apply(self, upcasted_inputs, [tt.tensor(dtype, broadcastable)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, y = inp x, y = inp
...@@ -2459,27 +2459,27 @@ class BatchedDot(Op): ...@@ -2459,27 +2459,27 @@ class BatchedDot(Op):
# x is a matrix, y is a tensor3, grad is a matrix # x is a matrix, y is a tensor3, grad is a matrix
elif xdim == 2 and ydim == 3: elif xdim == 2 and ydim == 3:
xgrad = T.batched_dot(gz, y.dimshuffle(0, 2, 1)) xgrad = tt.batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = x.dimshuffle(0, 1, "x") * gz.dimshuffle(0, "x", 1) ygrad = x.dimshuffle(0, 1, "x") * gz.dimshuffle(0, "x", 1)
# x is a tensor3, y is a matrix, grad is a matrix # x is a tensor3, y is a matrix, grad is a matrix
elif xdim == 3 and ydim == 2: elif xdim == 3 and ydim == 2:
xgrad = gz.dimshuffle(0, 1, "x") * y.dimshuffle(0, "x", 1) xgrad = gz.dimshuffle(0, 1, "x") * y.dimshuffle(0, "x", 1)
ygrad = T.batched_dot(x.dimshuffle(0, 2, 1), gz) ygrad = tt.batched_dot(x.dimshuffle(0, 2, 1), gz)
# x is a tensor3, y is a tensor3, grad is a tensor3 # x is a tensor3, y is a tensor3, grad is a tensor3
elif xdim == ydim == 3: elif xdim == ydim == 3:
xgrad = T.batched_dot(gz, y.dimshuffle(0, 2, 1)) xgrad = tt.batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = T.batched_dot(x.dimshuffle(0, 2, 1), gz) ygrad = tt.batched_dot(x.dimshuffle(0, 2, 1), gz)
# If x or y contain broadcastable dimensions but only one of # If x or y contain broadcastable dimensions but only one of
# them knows that a matching dimension is broadcastable, the # them knows that a matching dimension is broadcastable, the
# above code doesn't always return the right broadcast pattern. # above code doesn't always return the right broadcast pattern.
# This causes problems down the road. See gh-1461. # This causes problems down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable: if xgrad.broadcastable != x.broadcastable:
xgrad = T.patternbroadcast(xgrad, x.broadcastable) xgrad = tt.patternbroadcast(xgrad, x.broadcastable)
if ygrad.broadcastable != y.broadcastable: if ygrad.broadcastable != y.broadcastable:
ygrad = T.patternbroadcast(ygrad, y.broadcastable) ygrad = tt.patternbroadcast(ygrad, y.broadcastable)
return xgrad, ygrad return xgrad, ygrad
...@@ -2573,7 +2573,7 @@ batched_dot = BatchedDot() ...@@ -2573,7 +2573,7 @@ batched_dot = BatchedDot()
# from opt import register_specialize, register_canonicalize # from opt import register_specialize, register_canonicalize
# @register_specialize # @register_specialize
@local_optimizer([T.sub, T.add]) @local_optimizer([tt.sub, tt.add])
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add): if node.op in (tt.sub, tt.add):
debugprint(node) debugprint(node)
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论