提交 76ae1047 authored 作者: Frederic's avatar Frederic

Allow BNComposite to be inplace

上级 ea4e4147
...@@ -3329,6 +3329,8 @@ class Composite(ScalarOp): ...@@ -3329,6 +3329,8 @@ class Composite(ScalarOp):
Composite depends on all the Ops in its graph having C code. Composite depends on all the Ops in its graph having C code.
""" """
__props__ = ('inputs', 'outputs')
def __str__(self): def __str__(self):
return self.name return self.name
...@@ -3339,7 +3341,8 @@ class Composite(ScalarOp): ...@@ -3339,7 +3341,8 @@ class Composite(ScalarOp):
This fct allow fix patch this. This fct allow fix patch this.
""" """
out = self.__class__(self.inputs, self.outputs) d = dict([(k, getattr(self, k)) for k in self.__props__])
out = self.__class__(**d)
if name: if name:
out.name = name out.name = name
else: else:
......
...@@ -4,8 +4,10 @@ from theano.scalar import add, sub, true_div, mul ...@@ -4,8 +4,10 @@ from theano.scalar import add, sub, true_div, mul
class BNComposite(Composite): class BNComposite(Composite):
__props__ = ('dtype',)
def __init__(self, dtype): def __init__(self, dtype):
self.dtype = dtype
x = theano.scalar.Scalar(dtype=dtype).make_variable() x = theano.scalar.Scalar(dtype=dtype).make_variable()
mean = theano.scalar.Scalar(dtype=dtype).make_variable() mean = theano.scalar.Scalar(dtype=dtype).make_variable()
std = theano.scalar.Scalar(dtype=dtype).make_variable() std = theano.scalar.Scalar(dtype=dtype).make_variable()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论