提交 b3c674e1 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

factored default update flags out of the method signature

上级 fee6befb
...@@ -3963,12 +3963,17 @@ class Subtensor(Op): ...@@ -3963,12 +3963,17 @@ class Subtensor(Op):
indices.append(str(entry)) indices.append(str(entry))
return "%s{%s}" % (self.__class__.__name__, ", ".join(indices)) return "%s{%s}" % (self.__class__.__name__, ", ".join(indices))
@staticmethod
def default_update_flags():
return ("PyArray_UpdateFlags(xview,"
" NPY_ARRAY_C_CONTIGUOUS|"
"NPY_ARRAY_F_CONTIGUOUS);")
@staticmethod @staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list, def helper_c_code(node, name, inputs, outputs, sub, idx_list,
c_prefix="PyArray", c_prefix="PyArray",
update_flags=("PyArray_UpdateFlags(xview," update_flags=None,
" NPY_ARRAY_C_CONTIGUOUS|"
"NPY_ARRAY_F_CONTIGUOUS);"),
set_data='PyArray_set_data', set_data='PyArray_set_data',
set_dim='PyArray_set_dim', set_dim='PyArray_set_dim',
set_stride='PyArray_set_stride', set_stride='PyArray_set_stride',
...@@ -3979,6 +3984,10 @@ class Subtensor(Op): ...@@ -3979,6 +3984,10 @@ class Subtensor(Op):
function on PyArray and CudaNdarray object. function on PyArray and CudaNdarray object.
""" """
if update_flags is None:
update_flags = Subtensor.default_update_flags()
# #
# two arrays are created in C code: # two arrays are created in C code:
# is_slice: len == ndim, 0 means int, 1 means slice # is_slice: len == ndim, 0 means int, 1 means slice
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论