提交 f9b18d2a authored 作者: James Bergstra's avatar James Bergstra

added scalar_from_tensor, added support to Subtensor to index by 0-d tensors

上级 529bb53d
...@@ -463,9 +463,21 @@ class TensorFromScalar(Op): ...@@ -463,9 +463,21 @@ class TensorFromScalar(Op):
def perform(self): def perform(self):
self.outputs[0].data = self.inputs[0].data self.outputs[0].data = self.inputs[0].data
def grad(self, (s,), (dt,)): def grad(self, (s,), (dt,)):
raise NotImplementedError('todo: ScalarFromTensor') return [ScalarFromTensor(dt)]
tensor_from_scalar = gof.op.constructor(TensorFromScalar) tensor_from_scalar = gof.op.constructor(TensorFromScalar)
class ScalarFromTensor(Op):
def __init__(self, s, **kwargs):
assert isinstance(s, Tensor)
Op.__init__(self, **kwargs)
self.inputs = [s]
self.outputs = [scal.Scalar(s.dtype)]
def perform(self):
self.outputs[0].data = self.inputs[0].data
def grad(self, (s,), (dt,)):
return [TensorFromScalar(dt)]
scalar_from_tensor = gof.op.constructor(ScalarFromTensor)
########################## ##########################
# Unary Operations # Unary Operations
########################## ##########################
...@@ -673,6 +685,8 @@ class Subtensor(Op, Viewer): ...@@ -673,6 +685,8 @@ class Subtensor(Op, Viewer):
def asidx(i): def asidx(i):
if isinstance(i, int): return scal.constant(i) if isinstance(i, int): return scal.constant(i)
if isinstance(i, scal.Scalar) and ('int' in i.dtype): return i if isinstance(i, scal.Scalar) and ('int' in i.dtype): return i
if isinstance(i, Tensor) and ('int' in i.dtype) and (i.ndim == 0):
return scalar_from_tensor(i)
raise TypeError(Subtensor.e_invalid, i) raise TypeError(Subtensor.e_invalid, i)
x = _as_tensor(x) x = _as_tensor(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论