提交 c4c90d32 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Added corrections to ops.py

上级 ab9078d6
...@@ -9,6 +9,7 @@ import warnings ...@@ -9,6 +9,7 @@ import warnings
import theano import theano
from theano import gof from theano import gof
from theano.compat import OrderedDict
from six import iteritems from six import iteritems
from six.moves import xrange from six.moves import xrange
...@@ -138,9 +139,6 @@ class DeepCopyOp(gof.Op): ...@@ -138,9 +139,6 @@ class DeepCopyOp(gof.Op):
def __init__(self): def __init__(self):
pass pass
def __str__(self):
return self.__class__.__name__
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
...@@ -220,9 +218,6 @@ class Shape(gof.Op): ...@@ -220,9 +218,6 @@ class Shape(gof.Op):
check_input = False check_input = False
__props__ = () __props__ = ()
def __str__(self):
return self.__class__.__name__
def make_node(self, x): def make_node(self, x):
# Must work for all type that have a shape attribute. # Must work for all type that have a shape attribute.
# This will fail at execution time. # This will fail at execution time.
...@@ -609,10 +604,12 @@ class Rebroadcast(gof.Op): ...@@ -609,10 +604,12 @@ class Rebroadcast(gof.Op):
c_code_and_version = {} c_code_and_version = {}
check_input = False check_input = False
__props__ = ("axis") __props__ = ("axis",)
def __init__(self, *axis): def __init__(self, *axis):
self.axis = dict(axis) # Sort them to make sure we merge all possible case.
items = sorted(iteritems(self.axis))
self.axis = OrderedDict(items)
for axis, broad in iteritems(self.axis): for axis, broad in iteritems(self.axis):
assert isinstance(axis, (numpy.integer, int)), ( assert isinstance(axis, (numpy.integer, int)), (
"Rebroadcast needs integer axes. Got ", axis) "Rebroadcast needs integer axes. Got ", axis)
...@@ -749,9 +746,6 @@ class SpecifyShape(gof.Op): ...@@ -749,9 +746,6 @@ class SpecifyShape(gof.Op):
c_code_and_version = {} c_code_and_version = {}
__props__ = () __props__ = ()
def __str__(self):
return self.__class__.__name__
def make_node(self, x, shape): def make_node(self, x, shape):
if not isinstance(x, gof.Variable): if not isinstance(x, gof.Variable):
x = theano.tensor.as_tensor_variable(x) x = theano.tensor.as_tensor_variable(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论