提交 0f2fd4b4 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a clone method to TensorType and its equivalents.

上级 06cc52d7
...@@ -570,7 +570,7 @@ class Rebroadcast(gof.Op): ...@@ -570,7 +570,7 @@ class Rebroadcast(gof.Op):
def __hash__(self): def __hash__(self):
items = sorted(self.axis.iteritems()) # no ambiguity because each item key is unique items = sorted(self.axis.iteritems()) # no ambiguity because each item key is unique
return hash(type(self)) ^ hash(tuple(items)) return hash((type(self), tuple(items)))
def __str__(self): def __str__(self):
if len(self.axis) == 0: if len(self.axis) == 0:
...@@ -586,10 +586,9 @@ class Rebroadcast(gof.Op): ...@@ -586,10 +586,9 @@ class Rebroadcast(gof.Op):
def make_node(self, x): def make_node(self, x):
if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())): if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())):
raise ValueError('Trying to rebroadcast non-existent dimension') raise ValueError('Trying to rebroadcast non-existent dimension')
t = x.type.__class__(dtype=x.type.dtype, t = x.type.clone(broadcastable=[self.axis.get(i, b)
broadcastable=[self.axis.get(i, b) for i, b in enumerate(
for i, b in enumerate( x.type.broadcastable)])
x.type.broadcastable)])
return gof.Apply(self, [x], [t()]) return gof.Apply(self, [x], [t()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
......
...@@ -71,6 +71,11 @@ class CudaNdarrayType(Type): ...@@ -71,6 +71,11 @@ class CudaNdarrayType(Type):
self.name = name self.name = name
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
def clone(self, dtype=None, broadcastable=None):
if broadcastable is None:
broadcastable = self.broadcastable
return self.__class__(broadcastable, name=self.name, dtype=dtype)
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
return self.filter_inplace(data, None, strict=strict, return self.filter_inplace(data, None, strict=strict,
allow_downcast=allow_downcast) allow_downcast=allow_downcast)
......
...@@ -607,7 +607,6 @@ class GpuAlloc(HideC, Alloc): ...@@ -607,7 +607,6 @@ class GpuAlloc(HideC, Alloc):
def __init__(self, memset_0=False): def __init__(self, memset_0=False):
"""memset_0 is only an optimized version. True, it mean the """memset_0 is only an optimized version. True, it mean the
value is always 0, so the c code call memset as it is faster. value is always 0, so the c code call memset as it is faster.
""" """
self.memset_0 = memset_0 self.memset_0 = memset_0
......
...@@ -28,6 +28,14 @@ class GpuArrayType(Type): ...@@ -28,6 +28,14 @@ class GpuArrayType(Type):
raise TypeError("Unsupported dtype for %s: %s" % raise TypeError("Unsupported dtype for %s: %s" %
(self.__class__.__name__, self.dtype)) (self.__class__.__name__, self.dtype))
def clone(self, dtype=None, broadcastable=None):
if dtype is None:
dtype = self.dtype
if broadcastable is None:
broadcastable = self.broadcastable
return self.__class__(dtype=dtype, broadcastable=broadcastable,
name=self.name)
def __str__(self): def __str__(self):
return "GpuArrayType(%s, %s)" % (self.dtype, self.broadcastable) return "GpuArrayType(%s, %s)" % (self.dtype, self.broadcastable)
......
...@@ -2403,17 +2403,7 @@ class Alloc(gof.Op): ...@@ -2403,17 +2403,7 @@ class Alloc(gof.Op):
This Op is used to replace fill() during optimizations because after shapes This Op is used to replace fill() during optimizations because after shapes
are lifted, the first argument to fill can often be pruned from the graph. are lifted, the first argument to fill can often be pruned from the graph.
""" """
def __init__(self): __props__ = ()
pass
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def __str__(self):
return self.__class__.__name__
def make_node(self, value, *shape): def make_node(self, value, *shape):
v = as_tensor_variable(value) v = as_tensor_variable(value)
......
...@@ -52,6 +52,18 @@ class TensorType(Type): ...@@ -52,6 +52,18 @@ class TensorType(Type):
" AdvancedSubtensor1 sparse_grad. Now use" " AdvancedSubtensor1 sparse_grad. Now use"
" theano.sparse_grad(a_tensor[an_int_vector]).") " theano.sparse_grad(a_tensor[an_int_vector]).")
def clone(self, dtype=None, broadcastable=None):
"""
Return a copy of the type optionally with a new dtype or
broadcastable pattern.
"""
if dtype is None:
dtype = self.dtype
if broadcastable is None:
broadcastable = self.broadcastable
return self.__class__(dtype, broadcastable, name=self.name,
sparse_grad=self.sparse_grad)
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
"""Convert `data` to something which can be associated to a """Convert `data` to something which can be associated to a
`TensorVariable`. `TensorVariable`.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论