提交 9088d234 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

scalar.Scalar.desc, scalar.constant, scalar.Composite.impl

上级 ebef4846
...@@ -21,6 +21,11 @@ def as_scalar(x, name = None): ...@@ -21,6 +21,11 @@ def as_scalar(x, name = None):
if isinstance(x, Scalar): if isinstance(x, Scalar):
return x return x
def constant(x):
res = as_scalar(x)
res.constant = True
return res
class Scalar(Result): class Scalar(Result):
...@@ -40,7 +45,10 @@ class Scalar(Result): ...@@ -40,7 +45,10 @@ class Scalar(Result):
self._constant = value self._constant = value
constant = property(__get_constant, __set_constant) constant = property(__get_constant, __set_constant)
def desc(self):
return (self.dtype, self.data)
def filter(self, data): def filter(self, data):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
return py_type(data) return py_type(data)
...@@ -475,6 +483,12 @@ def composite(inputs, outputs): ...@@ -475,6 +483,12 @@ def composite(inputs, outputs):
for output, impl in zip(self.outputs, _impls): for output, impl in zip(self.outputs, _impls):
output.data = impl(inputs) output.data = impl(inputs)
def impl(self, *inputs):
for r, input in zip(self.inputs, inputs):
r.data = input
self.perform()
return utils.to_return_values([output.data for output in self.outputs])
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite") raise NotImplementedError("grad is not implemented for Composite")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论