提交 eda8fc3c authored 作者: James Bergstra's avatar James Bergstra

added __array_priority__ for Tensor variables.

上级 ded89f6c
...@@ -263,10 +263,12 @@ class TensorType(Type): ...@@ -263,10 +263,12 @@ class TensorType(Type):
return False return False
if 'int' in str(a.dtype): if 'int' in str(a.dtype):
return numpy.all(a==b) return numpy.all(a==b)
else: elif a.shape == (): #for comparing scalars, use broadcasting.
if a.shape == (): #for comparing scalars, use broadcasting.
ones = numpy.ones(2) ones = numpy.ones(2)
return numpy.allclose(ones * a, ones*b) return numpy.allclose(ones * a, ones*b)
#elif str(a.dtype).startswith('complex'):
# print >> sys.stderr, 'WARNING: skipping comparison of complex'
# return True
else: else:
return numpy.allclose(a,b) return numpy.allclose(a,b)
return False return False
...@@ -663,6 +665,10 @@ class _tensor_py_operators: ...@@ -663,6 +665,10 @@ class _tensor_py_operators:
return pow(pow(abs_(self), L).sum(axis=axis), 1.0/L) return pow(pow(abs_(self), L).sum(axis=axis), 1.0/L)
#TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000
class TensorVariable(Variable, _tensor_py_operators): class TensorVariable(Variable, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Variable` class.""" """Subclass to add the tensor operators to the basic `Variable` class."""
...@@ -898,7 +904,8 @@ class MaxAndArgmax(Op): ...@@ -898,7 +904,8 @@ class MaxAndArgmax(Op):
# g_max has one less dimension than x, so you need to complete g_max to x's shape # g_max has one less dimension than x, so you need to complete g_max to x's shape
# when axis=0 the broadcasting mechanism does it automatically # when axis=0 the broadcasting mechanism does it automatically
assert axis.data == 0 or axis.data == x.ndim-1 if not ( axis.data == 0 or axis.data == x.ndim-1):
raise NotImplementedError('MaxAndArgmax gradient with axis corresponding to internal dimension')
g_max_pad = shape_padleft(g_max) if axis.data==0 else \ g_max_pad = shape_padleft(g_max) if axis.data==0 else \
shape_padright(g_max) shape_padright(g_max)
xmax = max(x, axis) xmax = max(x, axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论