提交 067a3631 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use type.filter_variable instead of filter_update

上级 f13ab8b7
...@@ -103,7 +103,7 @@ def rebuild_collect_shared( outputs ...@@ -103,7 +103,7 @@ def rebuild_collect_shared( outputs
# Do not use default_update if a "real" update was # Do not use default_update if a "real" update was
# provided # provided
if v not in update_d: if v not in update_d:
v_update = v.filter_update(v.default_update) v_update = v.type.filter_variable(v.default_update)
if v_update.type != v.type: if v_update.type != v.type:
raise TypeError( raise TypeError(
( 'an update must have the same type as ' ( 'an update must have the same type as '
...@@ -188,8 +188,8 @@ def rebuild_collect_shared( outputs ...@@ -188,8 +188,8 @@ def rebuild_collect_shared( outputs
'expression'), 'expression'),
(store_into, update_d[store_into])) (store_into, update_d[store_into]))
update_val = store_into.filter_update(update_val) # filter_variable ensure smooth conversion of cpu/gpu Types
# typically this might be a cast() update_val = store_into.type.filter_variable(update_val)
if update_val.type != store_into.type: if update_val.type != store_into.type:
err_msg = ( 'an update must have the same type as the ' err_msg = ( 'an update must have the same type as the '
'original shared variable(dest, dest.type, ' 'original shared variable(dest, dest.type, '
......
...@@ -152,28 +152,33 @@ class Apply(utils.object2): ...@@ -152,28 +152,33 @@ class Apply(utils.object2):
:type strict: Bool :type strict: Bool
:param strict: :param strict:
If True, the type fields of all the inputs must be equal to the current ones, and If True, the type fields of all the inputs must be equal
returned outputs are guaranteed to have the same types as self.outputs. If False, to the current ones (or compatible, for instance Tensor /
then there's no guarantee that the clone's outputs will have the same types as CudaNdarray of the same dtype and broadcastable patterns,
self.outputs, and cloning may not even be possible (it depends on the Op). in which case they will be converted into current Type), and
returned outputs are guaranteed to have the same types as
self.outputs. If False, then there's no guarantee that the
clone's outputs will have the same types as self.outputs,
and cloning may not even be possible (it depends on the Op).
:returns: an Apply instance with the same op but different outputs. :returns: an Apply instance with the same op but different outputs.
""" """
remake_node = False remake_node = False
for curr, new in zip(self.inputs, inputs): new_inputs = inputs[:]
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
if not curr.type == new.type: if not curr.type == new.type:
if strict: if strict:
raise TypeError("Cannot change the type of this input.", ((curr, curr.type), # If compatible, casts new into curr.type
(new, new.type))) new_inputs[i] = curr.type.filter_variable(new)
else: else:
remake_node = True remake_node = True
if remake_node: if remake_node:
new_node = self.op.make_node(*inputs) new_node = self.op.make_node(*new_inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag) new_node.tag = copy(self.tag).__update__(new_node.tag)
else: else:
new_node = self.clone() new_node = self.clone()
new_node.inputs = inputs new_node.inputs = new_inputs
return new_node return new_node
#convenience properties #convenience properties
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论