提交 19092641 authored 作者: Frederic's avatar Frederic

pep8 fix.

上级 fce0a1e8
from theano.sparse.basic import * # To facilitate later merge into sparse module from theano.sparse.basic import * # To facilitate later merge into sparse module
from theano.sparse.basic import _is_sparse, _is_sparse_variable, \ from theano.sparse.basic import (
_is_dense_variable, _is_sparse, _is_dense, _kmap_eq, _kmap_hash _is_sparse, _is_sparse_variable, _is_dense_variable,
_is_sparse, _is_dense, _kmap_eq, _kmap_hash)
class Cast(gof.op.Op):
    """Cast the data of a sparse variable to `out_type`.

    The sparsity structure and format are preserved; only the dtype of
    the stored data changes.
    """

    def __init__(self, out_type):
        # Target dtype name, e.g. 'float32' or 'float64'.
        self.out_type = out_type

    def __eq__(self, other):
        return (type(self) == type(other)) and self.out_type == other.out_type

    def __hash__(self):
        return hash(type(self)) ^ hash(self.out_type)

    def make_node(self, x):
        x = as_sparse_variable(x)
        return gof.Apply(
            self, [x],
            [SparseType(dtype=self.out_type,
                        format=x.format).make_variable()])

    def perform(self, node, inputs, outputs):
        # Explicit unpacking replaces py2 tuple parameters (PEP 3113).
        (x,) = inputs
        (out,) = outputs
        assert _is_sparse(x)
        # Bug fix: the original stored `x` unchanged, so the runtime
        # dtype did not match the `out_type` promised by make_node.
        out[0] = x.astype(self.out_type)


# Pre-built casts for the two float precisions.
fcast = Cast('float32')
dcast = Cast('float64')
class Poisson(gof.op.Op):
    """Replace every non-zero element of a sparse matrix by a sample
    from a Poisson distribution whose mean is that element's value.

    Elements sampled to zero are removed from the sparsity structure.
    """

    def __eq__(self, other):
        return (type(self) == type(other))

    def __hash__(self):
        return hash(type(self))

    def make_node(self, x):
        x = as_sparse_variable(x)
        return gof.Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        # Explicit unpacking replaces py2 tuple parameters (PEP 3113).
        (x,) = inputs
        (out,) = outputs
        assert _is_sparse(x)
        out[0] = x.copy()
        out[0].data = numpy.asarray(numpy.random.poisson(out[0].data),
                                    dtype=x.dtype)
        # Sampling can produce zeros; drop them from the structure.
        out[0].eliminate_zeros()


poisson = Poisson()
class Multinomial(gof.op.Op):
    """Sample a multinomial distribution row-wise over a sparse matrix
    of probabilities.

    For each row ``i`` of `p`, draws
    ``numpy.random.multinomial(n[i], row_probs)`` over the row's
    non-zero entries.  Only the CSR format is supported.
    """

    def __eq__(self, other):
        return (type(self) == type(other))

    def __hash__(self):
        return hash(type(self))

    def make_node(self, n, p):
        n = tensor.as_tensor_variable(n)
        p = as_sparse_variable(p)
        return gof.Apply(self, [n, p], [p.type()])

    def perform(self, node, inputs, outputs):
        # Explicit unpacking replaces py2 tuple parameters (PEP 3113).
        (n, p) = inputs
        (out,) = outputs
        assert _is_sparse(p)
        if p.format != 'csr':
            # Bug fix: was `raise NotImplemented()`.  NotImplemented is
            # a comparison sentinel, not an exception class, so raising
            # it produced a confusing TypeError instead of the intended
            # signal.
            raise NotImplementedError()
        out[0] = p.copy()
        for i in range(p.shape[0]):
            # [k, l) is the slice of the data array holding row i.
            k, l = p.indptr[i], p.indptr[i + 1]
            out[0].data[k:l] = numpy.random.multinomial(n[i], p.data[k:l])


multinomial = Multinomial()
class EliminateZeros(gof.op.Op):
    """Remove explicitly stored zeros from a sparse matrix's structure.

    The returned matrix is numerically identical to the input; only the
    storage is compacted.
    """

    def __eq__(self, other):
        return (type(self) == type(other))

    def __hash__(self):
        return hash(type(self))

    def make_node(self, x):
        x = as_sparse_variable(x)
        return gof.Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        # Explicit unpacking replaces py2 tuple parameters (PEP 3113).
        (x,) = inputs
        (out,) = outputs
        assert _is_sparse(x)
        out[0] = x.copy()
        out[0].eliminate_zeros()


eliminate_zeros = EliminateZeros()
class Sum(gof.op.Op):
    """Sum a sparse matrix along axis `a`, returning a dense 1-d tensor
    of the input's dtype.
    """

    def __eq__(self, other):
        return (type(self) == type(other))

    def __hash__(self):
        return hash(type(self))

    def make_node(self, x, a):
        x = as_sparse_variable(x)
        a = tensor.as_tensor_variable(a)
        return gof.Apply(self, [x, a],
                         [tensor.TensorType(
                             dtype=x.type.dtype,
                             broadcastable=(False,)).make_variable()])

    def perform(self, node, inputs, outputs):
        # Explicit unpacking replaces py2 tuple parameters (PEP 3113).
        (x, a) = inputs
        (out,) = outputs
        assert _is_sparse(x)
        # scipy's .sum returns a matrix; flatten it to a 1-d array.
        out[0] = numpy.asarray(x.sum(a), dtype=x.dtype).flatten()


# NOTE(review): shadows the `sum` builtin at module level; kept as-is
# because existing callers import it under this name.
sum = Sum()
class Binomial(gof.op.Op):
    """Build a random sparse binary matrix of the given `shape`, in the
    given sparse `format`, with roughly ``n * p * shape[0] * shape[1]``
    ones placed uniformly at random.

    Duplicate (row, col) draws collapse to a single 1 because the data
    is reset to ones after the format conversion.
    """

    def __init__(self, format, dtype):
        self.format = format  # output sparse format ('csr' or 'csc')
        self.dtype = dtype    # output data dtype

    def __eq__(self, other):
        return ((type(self) == type(other)) and
                self.format == other.format and
                self.dtype == other.dtype)

    def __hash__(self):
        return hash(type(self)) ^ hash(self.format) ^ hash(self.dtype)

    def make_node(self, n, p, shape):
        n = tensor.as_tensor_variable(n)
        p = tensor.as_tensor_variable(p)
        shape = tensor.as_tensor_variable(shape)
        return gof.Apply(self, [n, p, shape],
                         [SparseType(dtype=self.dtype,
                                     format=self.format).make_variable()])

    def perform(self, node, inputs, outputs):
        # Explicit unpacking replaces py2 tuple parameters (PEP 3113).
        (n, p, shape) = inputs
        (out,) = outputs
        # Expected number of non-zero entries.
        # NOTE(review): n * p may be fractional, which newer numpy
        # rejects as an array size -- confirm callers pass values whose
        # product is integral.
        N = n * p * shape[0] * shape[1]
        data = numpy.ones(N, dtype=self.dtype)
        row = numpy.random.randint(0, shape[0], N)
        col = numpy.random.randint(0, shape[1], N)
        res = scipy.sparse.coo_matrix((data, (row, col)), shape=shape)
        # Dispatch to res.tocsr() / res.tocsc() based on self.format.
        out[0] = getattr(res, 'to' + self.format)()
        # Conversion summed duplicate coordinates; force binary values.
        out[0].data = numpy.ones_like(out[0].data)


csr_fbinomial = Binomial('csr', 'float32')
...@@ -116,16 +145,17 @@ csc_fbinomial = Binomial('csc', 'float32') ...@@ -116,16 +145,17 @@ csc_fbinomial = Binomial('csc', 'float32')
# Pre-built double-precision samplers for both sparse formats.
csr_dbinomial = Binomial('csr', 'float64')
csc_dbinomial = Binomial('csc', 'float64')
def structured_sigmoid(x):
    """Apply the element-wise sigmoid to the non-zero entries of a
    sparse variable; the sparsity structure is left untouched.
    """
    x = as_sparse_variable(x)
    data, indices, indptr, shape = csm_properties(x)
    squashed = tensor.nnet.sigmoid(data)
    return CSR(squashed, indices, indptr, shape)
def structured_exp(x):
    """Apply the element-wise exponential to the non-zero entries of a
    sparse variable; the sparsity structure is left untouched.
    """
    x = as_sparse_variable(x)
    x_data, x_ind, x_ptr, x_shape = csm_properties(x)
    x_data = tensor.exp(x_data)
    return CSR(x_data, x_ind, x_ptr, x_shape)
def structured_minimum(x, y):
    """Element-wise minimum between the non-zero entries of sparse `x`
    and tensor `y`; the sparsity structure is left untouched.
    """
    x = as_sparse_variable(x)
    y = tensor.as_tensor_variable(y)
    x_data, x_ind, x_ptr, x_shape = csm_properties(x)
    x_data = tensor.minimum(x_data, y)
    return CSR(x_data, x_ind, x_ptr, x_shape)
...@@ -179,24 +209,28 @@ class StructuredAddSV(gof.op.Op): ...@@ -179,24 +209,28 @@ class StructuredAddSV(gof.op.Op):
matrix.''' matrix.'''
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, x, y): def make_node(self, x, y):
x = as_sparse_variable(x) x = as_sparse_variable(x)
y = tensor.as_tensor_variable(y) y = tensor.as_tensor_variable(y)
assert y.type.ndim == 1 assert y.type.ndim == 1
if x.type.dtype != y.type.dtype: if x.type.dtype != y.type.dtype:
raise NotImplementedError() raise NotImplementedError()
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype = x.type.dtype, [SparseType(dtype=x.type.dtype,
format = x.type.format).make_variable()]) format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )): def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and not _is_sparse(y) assert _is_sparse(x) and not _is_sparse(y)
assert x.shape[1] == y.shape[0] assert x.shape[1] == y.shape[0]
out[0] = x.__class__(x + (x.toarray() != 0) * y) out[0] = x.__class__(x + (x.toarray() != 0) * y)
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_sparse_variable(y) assert _is_sparse_variable(x) and _is_sparse_variable(y)
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
...@@ -207,14 +241,18 @@ structured_add_s_v = StructuredAddSV() ...@@ -207,14 +241,18 @@ structured_add_s_v = StructuredAddSV()
class StrucutedAddSVCSR(gof.Op): class StrucutedAddSVCSR(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, a_data, a_indices, a_indptr, b): def make_node(self, a_data, a_indices, a_indptr, b):
assert b.type.ndim == 1 assert b.type.ndim == 1
return gof.Apply(self, [a_data, a_indices, a_indptr, b], return gof.Apply(self, [a_data, a_indices, a_indptr, b],
[tensor.tensor(b.dtype, (False,))]) [tensor.tensor(b.dtype, (False,))])
def c_code(self, node, name, (_data, _indices, _indptr, _b,), (_zout, ), sub):
def c_code(self, node, name, inputs, outputs, sub):
_data, _indices, _indptr, _b, = inputs
_zout, = outputs
if node.inputs[0].type.dtype in ('complex64', 'complex128'): if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a') raise NotImplementedError('Complex types are not supported for a')
if node.inputs[3].type.dtype in ('complex64', 'complex128'): if node.inputs[3].type.dtype in ('complex64', 'complex128'):
...@@ -272,98 +310,105 @@ class StrucutedAddSVCSR(gof.Op): ...@@ -272,98 +310,105 @@ class StrucutedAddSVCSR(gof.Op):
} }
} }
"""% dict(locals(), **sub) """ % dict(locals(), **sub)
structured_add_s_v_csr = StrucutedAddSVCSR() structured_add_s_v_csr = StrucutedAddSVCSR()
@gof.local_optimizer([structured_add_s_v])
def local_structured_add_s_v(node):
    """Specialize a structured_add_s_v node to its CSR implementation.

    Returns the replacement output list on a match, or False when the
    node does not apply (wrong op, or the dense operand is not 1-d).

    Raises NotImplementedError for sparse formats other than CSR.
    """
    if node.op == structured_add_s_v:
        x, y = node.inputs

        # Identify which operand is the sparse one.
        if _is_sparse_variable(x):
            svar, dvar = x, y
        else:
            svar, dvar = y, x

        if dvar.type.ndim != 1:
            return False
        elif svar.type.format == 'csr':
            CSx = CSR
            structured_add_s_v_csx = structured_add_s_v_csr
        else:
            # Bug fix: was `raise NotImplemented()`, which raises a
            # TypeError because NotImplemented is not an exception
            # class.
            raise NotImplementedError()

        s_val, s_ind, s_ptr, s_shape = csm_properties(svar)
        c_data = structured_add_s_v_csx(s_val, s_ind, s_ptr, dvar)
        return [CSx(c_data, s_ind, s_ptr, s_shape)]
    return False


register_specialize(local_structured_add_s_v)
class SamplingDot(gof.op.Op):
    """Compute ``P o (X . Y^T)``: the dot product of X and Y evaluated
    only where the sparse pattern P is non-zero (``o`` is the
    element-wise product).

    Unlike DOT, SamplingDot requires X to be an MxK matrix and Y an NxK
    matrix (instead of the usual KxN).

    The pattern need not be binary-valued, but if P is not very sparse
    this is slower than an optimized dense dot followed by a normal
    element-wise multiplication.
    """

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def __str__(self):
        return 'SamplingDot'

    def make_node(self, x, y, p):
        x = tensor.as_tensor_variable(x)
        y = tensor.as_tensor_variable(y)

        if not _is_sparse_variable(p):
            raise TypeError(p)

        # NOTE(review): the upcast result is unused; it only checks
        # that the three dtypes are compatible.
        dtype_out = scalar.upcast(x.type.dtype, y.type.dtype, p.type.dtype)

        return gof.Apply(self, [x, y, p], [p.type()])

    def perform(self, node, inputs, outputs):
        # Explicit unpacking replaces py2 tuple parameters (PEP 3113).
        (x, y, p) = inputs
        (out,) = outputs
        # Bug fix: the original called _is_sparse_variable on runtime
        # values, which is always False inside perform(), so the type
        # checks never fired.  Check the concrete values instead.
        if _is_sparse(x):
            raise TypeError(x)
        if _is_sparse(y):
            raise TypeError(y)
        if not _is_sparse(p):
            raise TypeError(p)

        # Restrict the dense dot product to p's non-zero pattern.
        rval = p.__class__(p.multiply(numpy.dot(x, y.T)))
        out[0] = rval

    def grad(self, inputs, grads):
        (x, y, p) = inputs
        (gz,) = grads
        # No gradient is propagated through the pattern p.
        rval = [
            dot(gz, y),
            dot(gz.T, x),
            None
        ]
        return rval


sampling_dot = SamplingDot()
class SamplingDotCsr(gof.Op): class SamplingDotCsr(gof.Op):
""" """
Optimized SamplingDot when the pattern P is a CSR matrix. Optimized SamplingDot when the pattern P is a CSR matrix.
...@@ -374,13 +419,13 @@ class SamplingDotCsr(gof.Op): ...@@ -374,13 +419,13 @@ class SamplingDotCsr(gof.Op):
""" """
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def __str__(self): def __str__(self):
return 'SamplingDot{Csr}' return 'SamplingDot{Csr}'
def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols): def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
y = tensor.as_tensor_variable(y) y = tensor.as_tensor_variable(y)
...@@ -388,12 +433,13 @@ class SamplingDotCsr(gof.Op): ...@@ -388,12 +433,13 @@ class SamplingDotCsr(gof.Op):
p_ind = tensor.as_tensor_variable(p_ind) p_ind = tensor.as_tensor_variable(p_ind)
p_ptr = tensor.as_tensor_variable(p_ptr) p_ptr = tensor.as_tensor_variable(p_ptr)
p_ncols = tensor.as_tensor_variable(p_ncols) p_ncols = tensor.as_tensor_variable(p_ncols)
assert p_ncols.dtype == 'int32' assert p_ncols.dtype == 'int32'
dtype_out = scalar.upcast(x.type.dtype, y.type.dtype, p_data.type.dtype) dtype_out = scalar.upcast(x.type.dtype, y.type.dtype,
p_data.type.dtype)
dot_out = scalar.upcast(x.type.dtype, y.type.dtype) dot_out = scalar.upcast(x.type.dtype, y.type.dtype)
# We call blas ?dot function that take only param of the same type # We call blas ?dot function that take only param of the same type
x = tensor.cast(x, dot_out) x = tensor.cast(x, dot_out)
y = tensor.cast(y, dot_out) y = tensor.cast(y, dot_out)
...@@ -406,7 +452,7 @@ class SamplingDotCsr(gof.Op): ...@@ -406,7 +452,7 @@ class SamplingDotCsr(gof.Op):
def c_support_code(self): def c_support_code(self):
return blas.blas_header_text() return blas.blas_header_text()
def c_libraries(self): def c_libraries(self):
import pdb; pdb.set_trace() import pdb; pdb.set_trace()
return blas.ldflags() return blas.ldflags()
...@@ -416,19 +462,24 @@ class SamplingDotCsr(gof.Op): ...@@ -416,19 +462,24 @@ class SamplingDotCsr(gof.Op):
def c_lib_dirs(self): def c_lib_dirs(self):
return blas.ldflags(libs=False, libs_dir=True) return blas.ldflags(libs=False, libs_dir=True)
def c_header_dirs(self): def c_header_dirs(self):
return blas.ldflags(libs=False, include_dir=True) return blas.ldflags(libs=False, include_dir=True)
def c_code(self, node, name, (x, y, p_data, p_ind, p_ptr, p_ncols), (z_data, z_ind, z_ptr), sub): def c_code(self, node, name, inputs, outputs, sub):
x, y, p_data, p_ind, p_ptr, p_ncols = inputs
z_data, z_ind, z_ptr = outputs
if node.inputs[0].type.dtype in ('complex64', 'complex128'): if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for x') raise NotImplementedError('Complex types are not supported for x')
if node.inputs[1].type.dtype in ('complex64', 'complex128'): if node.inputs[1].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for y') raise NotImplementedError('Complex types are not supported for y')
if node.inputs[2].type.dtype in ('complex64', 'complex128'): if node.inputs[2].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for pattern') raise NotImplementedError(
'Complex types are not supported for pattern')
dot_out = scalar.upcast(node.inputs[0].type.dtype, node.inputs[0].type.dtype)
# TODO: why 2 times the same inputs?
dot_out = scalar.upcast(node.inputs[0].type.dtype,
node.inputs[0].type.dtype)
if dot_out == "float32": if dot_out == "float32":
conv_type = "float" conv_type = "float"
...@@ -436,13 +487,17 @@ class SamplingDotCsr(gof.Op): ...@@ -436,13 +487,17 @@ class SamplingDotCsr(gof.Op):
else: else:
conv_type = "double" conv_type = "double"
cdot = "ddot_sub_" cdot = "ddot_sub_"
typenum_x = node.inputs[0].type.dtype_specs()[-1] # retrieve dtype number # retrieve dtype number
typenum_y = node.inputs[1].type.dtype_specs()[-1] # retrieve dtype number typenum_x = node.inputs[0].type.dtype_specs()[-1]
typenum_p = node.inputs[2].type.dtype_specs()[-1] # retrieve dtype number typenum_y = node.inputs[1].type.dtype_specs()[-1]
typenum_zd = tensor.TensorType(node.outputs[0].dtype, []).dtype_specs()[-1] # retrieve dtype number typenum_p = node.inputs[2].type.dtype_specs()[-1]
typenum_zi = tensor.TensorType(node.outputs[1].dtype, []).dtype_specs()[-1] # retrieve dtype number typenum_zd = tensor.TensorType(node.outputs[0].dtype,
typenum_zp = tensor.TensorType(node.outputs[2].dtype, []).dtype_specs()[-1] # retrieve dtype number []).dtype_specs()[-1]
typenum_zi = tensor.TensorType(node.outputs[1].dtype,
[]).dtype_specs()[-1]
typenum_zp = tensor.TensorType(node.outputs[2].dtype,
[]).dtype_specs()[-1]
rval = """ rval = """
if (%(x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;} if (%(x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
...@@ -529,13 +584,14 @@ class SamplingDotCsr(gof.Op): ...@@ -529,13 +584,14 @@ class SamplingDotCsr(gof.Op):
Dzd[n_idx * Sdzd] *= Dpd[n_idx * Sdpd]; Dzd[n_idx * Sdzd] *= Dpd[n_idx * Sdpd];
} }
} }
} }
"""% dict(locals(), **sub) """ % dict(locals(), **sub)
return rval return rval
sampling_dot_csr = SamplingDotCsr() sampling_dot_csr = SamplingDotCsr()
# register a specialization to replace SamplingDot -> SamplingDotCsr # register a specialization to replace SamplingDot -> SamplingDotCsr
@gof.local_optimizer([sampling_dot]) @gof.local_optimizer([sampling_dot])
def local_sampling_dot_csr(node): def local_sampling_dot_csr(node):
...@@ -543,10 +599,10 @@ def local_sampling_dot_csr(node): ...@@ -543,10 +599,10 @@ def local_sampling_dot_csr(node):
x, y, p = node.inputs x, y, p = node.inputs
if p.type.format == 'csr': if p.type.format == 'csr':
p_data, p_ind, p_ptr, p_shape = csm_properties(p) p_data, p_ind, p_ptr, p_shape = csm_properties(p)
z_data, z_ind, z_ptr = sampling_dot_csr(x, y, p_data, z_data, z_ind, z_ptr = sampling_dot_csr(x, y, p_data,
p_ind, p_ptr, p_shape[1]) p_ind, p_ptr, p_shape[1])
return [CSR(z_data, z_ind, z_ptr, p_shape)] return [CSR(z_data, z_ind, z_ptr, p_shape)]
return False return False
register_specialize(local_sampling_dot_csr, name='local_sampling_dot_csr') register_specialize(local_sampling_dot_csr, name='local_sampling_dot_csr')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论