提交 042e4a67 authored 作者: Larry Dong's avatar Larry Dong 提交者: Brandon T. Willard

change at.diff to slicing + subtraction

上级 e5ebf260
...@@ -3,6 +3,7 @@ from typing import Iterable, Tuple, Union ...@@ -3,6 +3,7 @@ from typing import Iterable, Tuple, Union
import numpy as np import numpy as np
import numpy.core.numeric import numpy.core.numeric
from numpy.core.multiarray import normalize_axis_index
import aesara import aesara
from aesara.gradient import ( from aesara.gradient import (
...@@ -483,7 +484,7 @@ class DiffOp(Op): ...@@ -483,7 +484,7 @@ class DiffOp(Op):
def make_node(self, x): def make_node(self, x):
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
axis = numpy.core.numeric.normalize_axis_index(self.axis, x.ndim) axis = normalize_axis_index(self.axis, x.ndim)
shape = [None] * x.type.ndim shape = [None] * x.type.ndim
for i, shape_i in enumerate(x.type.shape): for i, shape_i in enumerate(x.type.shape):
if shape_i is None: if shape_i is None:
...@@ -533,7 +534,7 @@ def diff(x, n=1, axis=-1): ...@@ -533,7 +534,7 @@ def diff(x, n=1, axis=-1):
The first order difference is given by ``out[i] = a[i + 1] - a[i]`` The first order difference is given by ``out[i] = a[i + 1] - a[i]``
along the given `axis`, higher order differences are calculated by along the given `axis`, higher order differences are calculated by
using `diff` recursively. This wraps ``numpy.diff``. using `diff` recursively. This is heavily inspired by ``numpy.diff``.
Parameters Parameters
---------- ----------
...@@ -548,7 +549,20 @@ def diff(x, n=1, axis=-1): ...@@ -548,7 +549,20 @@ def diff(x, n=1, axis=-1):
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
return DiffOp(n=n, axis=axis)(x) ndim = x.ndim
axis = normalize_axis_index(axis, ndim)
slice1 = [slice(None)] * ndim
slice2 = [slice(None)] * ndim
slice1[axis] = slice(1, None)
slice2[axis] = slice(None, -1)
slice1 = tuple(slice1)
slice2 = tuple(slice2)
for _ in range(n):
x = x[slice1] - x[slice2]
return x
def bincount(x, weights=None, minlength=None, assert_nonneg=False): def bincount(x, weights=None, minlength=None, assert_nonneg=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论