提交 f9b18d2a authored 作者: James Bergstra's avatar James Bergstra

added scalar_from_tensor, added support to Subtensor to index by 0-d tensors

上级 529bb53d
......@@ -463,9 +463,21 @@ class TensorFromScalar(Op):
def perform(self):
self.outputs[0].data = self.inputs[0].data
def grad(self, (s,), (dt,)):
raise NotImplementedError('todo: ScalarFromTensor')
return [ScalarFromTensor(dt)]
tensor_from_scalar = gof.op.constructor(TensorFromScalar)
class ScalarFromTensor(Op):
def __init__(self, s, **kwargs):
assert isinstance(s, Tensor)
Op.__init__(self, **kwargs)
self.inputs = [s]
self.outputs = [scal.Scalar(s.dtype)]
def perform(self):
self.outputs[0].data = self.inputs[0].data
def grad(self, (s,), (dt,)):
return [TensorFromScalar(dt)]
scalar_from_tensor = gof.op.constructor(ScalarFromTensor)
##########################
# Unary Operations
##########################
......@@ -673,6 +685,8 @@ class Subtensor(Op, Viewer):
def asidx(i):
if isinstance(i, int): return scal.constant(i)
if isinstance(i, scal.Scalar) and ('int' in i.dtype): return i
if isinstance(i, Tensor) and ('int' in i.dtype) and (i.ndim == 0):
return scalar_from_tensor(i)
raise TypeError(Subtensor.e_invalid, i)
x = _as_tensor(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论