提交 1c1285ab authored 作者: James Bergstra's avatar James Bergstra

merge

...@@ -311,6 +311,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape) ...@@ -311,6 +311,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
} }
int shp_el = PyInt_AsLong(shp_el_obj); int shp_el = PyInt_AsLong(shp_el_obj);
Py_DECREF(shp_el_obj);
if (shp_el <= 0) if (shp_el <= 0)
{ {
...@@ -318,9 +319,8 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape) ...@@ -318,9 +319,8 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
free(newdims); free(newdims);
return NULL; return NULL;
} }
newdims[i] = shp_el; newdims[i] = shp_el;
total_elements *= newdims[i]; total_elements *= newdims[i];
} }
...@@ -1425,23 +1425,26 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v) ...@@ -1425,23 +1425,26 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v)
Py_DECREF((PyObject*)rval); Py_DECREF((PyObject*)rval);
return -1; return -1;
} }
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)v)) if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)v))
{ {
Py_DECREF(viewCopyForComparison); Py_DECREF(viewCopyForComparison);
Py_DECREF((PyObject*)rval); Py_DECREF((PyObject*)rval);
return -1; return -1;
} }
// Check that copy didn't modify shape or strides // Check that copy didn't modify shape or strides
assert (CudaNdarray_EqualAndIgnore(viewCopyForComparison, rval, 1, 1)); assert (CudaNdarray_EqualAndIgnore(viewCopyForComparison, rval, 1, 1));
assert (rval->base == baseSavedForComparison); assert (rval->base == baseSavedForComparison);
assert (rval->dev_structure_fresh); assert (rval->dev_structure_fresh);
// Clean up locally-created references
Py_DECREF((PyObject*)viewCopyForComparison); Py_DECREF((PyObject*)viewCopyForComparison);
Py_DECREF(rval);
return 0; return 0;
} }
PyMappingMethods CudaNdarrayMappingMethods = { PyMappingMethods CudaNdarrayMappingMethods = {
CudaNdarray_len, //lenfunc mp_length; __len__ CudaNdarray_len, //lenfunc mp_length; __len__
......
...@@ -149,6 +149,10 @@ class mrg_uniform_base(Op): ...@@ -149,6 +149,10 @@ class mrg_uniform_base(Op):
return Apply(self, return Apply(self,
[rstate, size], [rstate, size],
[rstate.type(), self.output_type()]) [rstate.type(), self.output_type()])
def grad(self,inputs,ograd):
return [None for i in inputs]
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
...@@ -717,5 +721,3 @@ def mrg_random_make_inplace(node): ...@@ -717,5 +721,3 @@ def mrg_random_make_inplace(node):
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*node.inputs).outputs
return False return False
optdb.register('random_make_inplace_mrg', opt.in2out(mrg_random_make_inplace, ignore_newtrees=True), 99, 'fast_run', 'inplace') optdb.register('random_make_inplace_mrg', opt.in2out(mrg_random_make_inplace, ignore_newtrees=True), 99, 'fast_run', 'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论