提交 3e146aa3 authored 作者: James Bergstra's avatar James Bergstra

test of elemwise doc

上级 8c8c4b6e
...@@ -288,6 +288,9 @@ class Elemwise(Op): ...@@ -288,6 +288,9 @@ class Elemwise(Op):
else: else:
return self.name return self.name
def __repr__(self):
return self.__str__()
def grad(self, inputs, ograds): def grad(self, inputs, ograds):
ograds = map(as_tensor, ograds) # this shouldn't be necessary... ograds = map(as_tensor, ograds) # this shouldn't be necessary...
scalar_inputs = [Scalar(dtype = t.type.dtype)() for t in inputs] scalar_inputs = [Scalar(dtype = t.type.dtype)() for t in inputs]
......
...@@ -664,6 +664,24 @@ def argmax(x, axis=None): ...@@ -664,6 +664,24 @@ def argmax(x, axis=None):
# Comparison # Comparison
########################## ##########################
def _elemwise_macro(scalar_op, *args):
straight = elemwise.Elemwise(scalar_op)
return straight(*args)
def _elemwise_macro_inplace(scalar_op, *args):
#construct an inplace version of the scalar op
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0})
return inplace(*args)
def lt(a, b):
"""asdfasdf"""
return _elemwise_macro(scal.lt, a, b)
def _lt_inplace(a,b):
"""asdfasdf inplace!"""
return _elemwise_macro_inplace(scal.lt, a, b)
lt, _lt_inplace = _elemwise(scal.lt, 'lt', lt, _lt_inplace = _elemwise(scal.lt, 'lt',
"""less than (elemwise)""") """less than (elemwise)""")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论