提交 27e259fb authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

...@@ -120,8 +120,8 @@ class Component(object): ...@@ -120,8 +120,8 @@ class Component(object):
Makes an instance of this Component using the mode provided Makes an instance of this Component using the mode provided
and taking the containers in the memo dictionary. and taking the containers in the memo dictionary.
A Component which builds nothing, such as External or A Component which builds nothing, such as External, may return
Temporary, may return None. None.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -250,7 +250,6 @@ class External(_RComponent): ...@@ -250,7 +250,6 @@ class External(_RComponent):
return rval return rval
class Member(_RComponent): class Member(_RComponent):
""" """
Member represents a Result which is a state of a Composite. That Member represents a Result which is a state of a Composite. That
......
...@@ -747,6 +747,7 @@ class Composite(ScalarOp): ...@@ -747,6 +747,7 @@ class Composite(ScalarOp):
def __init__(self, inputs, outputs): def __init__(self, inputs, outputs):
env = Env(*gof.graph.clone(inputs, outputs)) env = Env(*gof.graph.clone(inputs, outputs))
gof.MergeOptimizer().optimize(env)
inputs, outputs = env.inputs, env.outputs inputs, outputs = env.inputs, env.outputs
for node in env.nodes: for node in env.nodes:
......
...@@ -214,6 +214,9 @@ class Tensor(Type): ...@@ -214,6 +214,9 @@ class Tensor(Type):
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype)) raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
def to_scalar_type(self):
return scal.Scalar(dtype = self.dtype)
def __eq__(self, other): def __eq__(self, other):
"""Compare True iff other is the same kind of Tensor""" """Compare True iff other is the same kind of Tensor"""
return type(self) == type(other) and other.dtype == self.dtype and other.broadcastable == self.broadcastable return type(self) == type(other) and other.dtype == self.dtype and other.broadcastable == self.broadcastable
......
...@@ -600,6 +600,14 @@ class Elemwise(Op): ...@@ -600,6 +600,14 @@ class Elemwise(Op):
code = "\n".join(self._c_all(node, name, inames, onames, sub)) code = "\n".join(self._c_all(node, name, inames, onames, sub))
return code return code
# def elemwise_to_scal(env):
# mapping = {}
# inputs = []
# outputs = []
# for node in env.io_toposort():
# if not isinstance(node.op, Elemwise):
# raise TypeError('All ops in the graph must be Elemwise.')
################ ################
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论