提交 5d1b70a1 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

changed comments related to deprecated Value class

上级 9c3f95cc
...@@ -44,7 +44,7 @@ class OpFromGraph(gof.Op): ...@@ -44,7 +44,7 @@ class OpFromGraph(gof.Op):
if 'updates' in kwargs: if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs') raise TypeError('updates are not allowed in kwargs')
# TODO: the graph may have implicit inputs like Value and # TODO: the graph may have implicit inputs like
# SharedVariable instances. # SharedVariable instances.
# what impact to they have on the validity of this Op? # what impact to they have on the validity of this Op?
self.fn = orig_function(inputs, outputs, **kwargs) self.fn = orig_function(inputs, outputs, **kwargs)
......
...@@ -477,10 +477,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -477,10 +477,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
if isinstance(param, Constant): if isinstance(param, Constant):
raise TypeError('Constants not allowed in param list', param) raise TypeError('Constants not allowed in param list', param)
#if isinstance(param, Value): if isinstance(param, Variable): # N.B. includes SharedVariable
#return In(variable=param)
#raise NotImplementedError()
if isinstance(param, Variable): # N.B. includes Value and SharedVariable
return In(variable=param, strict=strict, allow_downcast=allow_downcast) return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param): elif isinstance(param, Param):
return In( return In(
......
...@@ -28,7 +28,7 @@ class Env(utils.object2): ...@@ -28,7 +28,7 @@ class Env(utils.object2):
""" WRITEME """ WRITEME
An Env represents a subgraph bound by a set of input variables and a An Env represents a subgraph bound by a set of input variables and a
set of output variables. The inputs list should contain all the inputs set of output variables. The inputs list should contain all the inputs
on which the outputs depend. Variables of type Value or Constant are on which the outputs depend. Variables of type Constant are
not counted as inputs. not counted as inputs.
The Env supports the replace operation which allows to replace a The Env supports the replace operation which allows to replace a
......
...@@ -217,8 +217,6 @@ class Variable(utils.object2): ...@@ -217,8 +217,6 @@ class Variable(utils.object2):
- `Variable` (this base type) is typically the output of a symbolic computation, - `Variable` (this base type) is typically the output of a symbolic computation,
- `Value` (a subclass) adds a default :literal:`value`, and requires that owner is None
- `Constant` (a subclass) which adds a default and un-replaceable :literal:`value`, and - `Constant` (a subclass) which adds a default and un-replaceable :literal:`value`, and
requires that owner is None requires that owner is None
...@@ -396,6 +394,7 @@ class Constant(Value): ...@@ -396,6 +394,7 @@ class Constant(Value):
if len(name) > 20: if len(name) > 20:
name = name[:10] + '...' + name[-10] name = name[:10] + '...' + name[-10]
return 'Constant{%s}' % name return 'Constant{%s}' % name
def clone(self): def clone(self):
""" """
We clone this object, but we don't clone the data to lower memory requirement We clone this object, but we don't clone the data to lower memory requirement
......
...@@ -212,17 +212,6 @@ def constant(x, name=None): ...@@ -212,17 +212,6 @@ def constant(x, name=None):
except TypeError: except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x)) raise TypeError("Could not convert %s to SparseType" % x, type(x))
if 0:
def value(x):
if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.value must be called on a "
"scipy.sparse.spmatrix")
try:
return SparseValue(SparseType(format=x.format,
dtype=x.dtype), x)
except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
def sp_ones_like(x): def sp_ones_like(x):
# TODO: don't restrict to CSM formats # TODO: don't restrict to CSM formats
...@@ -760,19 +749,19 @@ class CSMGrad(gof.op.Op): ...@@ -760,19 +749,19 @@ class CSMGrad(gof.op.Op):
sp_dim = x_shape[1] sp_dim = x_shape[1]
else: else:
sp_dim = x_shape[0] sp_dim = x_shape[0]
g_row = numpy.zeros(sp_dim, dtype=g_data.dtype) g_row = numpy.zeros(sp_dim, dtype=g_data.dtype)
gout_data = numpy.zeros_like(x_data) gout_data = numpy.zeros_like(x_data)
for i in range(len(x_indptr) - 1): for i in range(len(x_indptr) - 1):
for j_ptr in range(g_indptr[i], g_indptr[i + 1]): for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
g_row[g_indices[j_ptr]] += g_data[j_ptr] g_row[g_indices[j_ptr]] += g_data[j_ptr]
for j_ptr in range(x_indptr[i], x_indptr[i + 1]): for j_ptr in range(x_indptr[i], x_indptr[i + 1]):
gout_data[j_ptr] = g_row[x_indices[j_ptr]] gout_data[j_ptr] = g_row[x_indices[j_ptr]]
for j_ptr in range(g_indptr[i], g_indptr[i + 1]): for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
g_row[g_indices[j_ptr]] = 0 g_row[g_indices[j_ptr]] = 0
if self.kmap is None: if self.kmap is None:
g_out[0] = gout_data g_out[0] = gout_data
else: else:
...@@ -811,7 +800,7 @@ class CSMGradC(gof.Op): ...@@ -811,7 +800,7 @@ class CSMGradC(gof.Op):
raise NotImplementedError('Complex types are not supported for a_val') raise NotImplementedError('Complex types are not supported for a_val')
if node.inputs[3].type.dtype in ('complex64', 'complex128'): if node.inputs[3].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for b_val') raise NotImplementedError('Complex types are not supported for b_val')
return """ return """
if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;}
if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;}
...@@ -825,22 +814,22 @@ class CSMGradC(gof.Op): ...@@ -825,22 +814,22 @@ class CSMGradC(gof.Op):
if (%(a_ptr)s->descr->type_num != PyArray_INT32) if (%(a_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;}
if (%(b_ind)s->descr->type_num != PyArray_INT32) { if (%(b_ind)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "b_ind dtype not INT32"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "b_ind dtype not INT32"); %(fail)s;}
if (%(b_ptr)s->descr->type_num != PyArray_INT32) if (%(b_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "b_ptr dtype not INT32"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b_ptr dtype not INT32"); %(fail)s;}
if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0]) if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;}
if (%(b_val)s->dimensions[0] != %(b_ind)s->dimensions[0]) if (%(b_val)s->dimensions[0] != %(b_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "b_val and b_ind have different lengths"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b_val and b_ind have different lengths"); %(fail)s;}
if (%(a_ptr)s->dimensions[0] != %(b_ptr)s->dimensions[0]) if (%(a_ptr)s->dimensions[0] != %(b_ptr)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr and b_ptr have different lengths"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_ptr and b_ptr have different lengths"); %(fail)s;}
if ((!%(z)s) || (%(z)s->dimensions[0] != %(a_val)s->dimensions[0])) if ((!%(z)s) || (%(z)s->dimensions[0] != %(a_val)s->dimensions[0]))
{ {
{Py_XDECREF(%(z)s);} {Py_XDECREF(%(z)s);}
...@@ -854,9 +843,9 @@ class CSMGradC(gof.Op): ...@@ -854,9 +843,9 @@ class CSMGradC(gof.Op):
npy_intp M = %(a_ptr)s->dimensions[0] - 1; npy_intp M = %(a_ptr)s->dimensions[0] - 1;
npy_intp a_dim_0 = ((npy_int32 *)%(a_dim)s->data)[0]; npy_intp a_dim_0 = ((npy_int32 *)%(a_dim)s->data)[0];
npy_intp a_dim_1 = ((npy_int32 *)%(a_dim)s->data)[1]; npy_intp a_dim_1 = ((npy_int32 *)%(a_dim)s->data)[1];
npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0;
// strides tell you how many bytes to skip to go to next column/row entry // strides tell you how many bytes to skip to go to next column/row entry
npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize; npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize;
npy_intp Sa_val = %(a_val)s->strides[0] / %(a_val)s->descr->elsize; npy_intp Sa_val = %(a_val)s->strides[0] / %(a_val)s->descr->elsize;
...@@ -876,9 +865,9 @@ class CSMGradC(gof.Op): ...@@ -876,9 +865,9 @@ class CSMGradC(gof.Op):
const npy_int32 * __restrict__ Db_ptr = (npy_int32*)%(b_ptr)s->data; const npy_int32 * __restrict__ Db_ptr = (npy_int32*)%(b_ptr)s->data;
npy_intp nnz = %(a_ind)s->dimensions[0]; npy_intp nnz = %(a_ind)s->dimensions[0];
dtype_%(b_val)s b_row[sp_dim]; dtype_%(b_val)s b_row[sp_dim];
//clear the output array //clear the output array
for (npy_int64 i = 0; i < nnz; ++i) for (npy_int64 i = 0; i < nnz; ++i)
{ {
...@@ -893,12 +882,12 @@ class CSMGradC(gof.Op): ...@@ -893,12 +882,12 @@ class CSMGradC(gof.Op):
j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) { j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) {
b_row[Db_ind[j_ptr * Sb_ind]] += Db_val[j_ptr*Sb_val]; b_row[Db_ind[j_ptr * Sb_ind]] += Db_val[j_ptr*Sb_val];
} }
for (npy_int32 j_ptr = Da_ptr[m * Sa_ptr]; for (npy_int32 j_ptr = Da_ptr[m * Sa_ptr];
j_ptr < Da_ptr[(m + 1) * Sa_ptr]; j_ptr++) { j_ptr < Da_ptr[(m + 1) * Sa_ptr]; j_ptr++) {
Dz[j_ptr*Sz] = b_row[Da_ind[j_ptr * Sa_ind]]; Dz[j_ptr*Sz] = b_row[Da_ind[j_ptr * Sa_ind]];
} }
for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr]; for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr];
j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) { j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) {
b_row[Db_ind[j_ptr * Sb_ind]] = 0; b_row[Db_ind[j_ptr * Sb_ind]] = 0;
......
...@@ -2278,7 +2278,7 @@ class MaxAndArgmax(Op): ...@@ -2278,7 +2278,7 @@ class MaxAndArgmax(Op):
# not calculated here for it is not defined at every point where some # not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null # coordinates are identical. However, since the latter set has null
# Lebesgue measure, the result may be interpreted as weak gradient. # Lebesgue measure, the result may be interpreted as weak gradient.
# @note: This function should work correctly for L{vector}s. # @note: This function should work correctly for L{vector}s.
# (x, y), (gz, gw) # (x, y), (gz, gw)
# gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
...@@ -2314,7 +2314,7 @@ class MaxAndArgmax(Op): ...@@ -2314,7 +2314,7 @@ class MaxAndArgmax(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
_max_and_argmax = MaxAndArgmax() _max_and_argmax = MaxAndArgmax()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论