提交 106625e3 authored 作者: Frederic Bastien's avatar Frederic Bastien

make the ScanGrad op handle ProfileMode as the Scan op.

上级 40db8188
...@@ -1021,12 +1021,14 @@ class Scan(Op): ...@@ -1021,12 +1021,14 @@ class Scan(Op):
#we must make a copy otherwise in the profile their will time counted many times #we must make a copy otherwise in the profile their will time counted many times
#1) The scan op and its time will include all time spend into the inner node. #1) The scan op and its time will include all time spend into the inner node.
#2) The inner scan op with their real time. #2) The inner scan op with their real time.
#This is done for the Scan and ScanGred op
if mode is None and isinstance(mode_instance, compile.profilemode.ProfileMode): if mode is None and isinstance(mode_instance, compile.profilemode.ProfileMode):
mode_instance = compile.profilemode.ProfileMode( mode_instance = compile.profilemode.ProfileMode(
optimizer=mode_instance.provided_optimizer, optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker) linker=mode_instance.provided_linker)
compile.profilemode.prof_mode_instance_to_print.append(mode_instance) compile.profilemode.prof_mode_instance_to_print.append(mode_instance)
self.mode_instance = mode_instance self.mode_instance = mode_instance
self.mode_instance.message="Scan sub profile"
if name is None: name = 'scan_fn' if name is None: name = 'scan_fn'
self.fn = function(inputs,outputs, mode = mode_instance, givens = givens, self.fn = function(inputs,outputs, mode = mode_instance, givens = givens,
...@@ -1453,10 +1455,13 @@ class ScanGrad(Op): ...@@ -1453,10 +1455,13 @@ class ScanGrad(Op):
def __init__(self,(g_ins, g_outs) , n_seqs, n_outs, def __init__(self,(g_ins, g_outs) , n_seqs, n_outs,
n_outs_not_shared, n_outs_not_shared,
go_backwards = False, seqs_taps = {}, outs_taps= {}, go_backwards = False, seqs_taps = {}, outs_taps= {},
truncate_gradient = -1): truncate_gradient = -1, mode = None, name = None):
"""
:param mode: see scan fct
:param name: see scan fct
"""
self.grad_fn = function(g_ins, g_outs)
self.inputs = g_ins self.inputs = g_ins
self.outputs = g_outs self.outputs = g_outs
self.n_outs_not_shared = n_outs_not_shared self.n_outs_not_shared = n_outs_not_shared
...@@ -1467,7 +1472,24 @@ class ScanGrad(Op): ...@@ -1467,7 +1472,24 @@ class ScanGrad(Op):
self.seqs_taps = seqs_taps self.seqs_taps = seqs_taps
self.outs_taps = outs_taps self.outs_taps = outs_taps
self.destroy_map = {} self.destroy_map = {}
self.mode = mode
mode_instance = compile.mode.get_mode(mode)
#if we use the default mode and it is a ProfileMode
#we must make a copy otherwise in the profile their will time counted many times
#1) The scan op and its time will include all time spend into the inner node.
#2) The inner scan op with their real time.
#This is done for the Scan and ScanGred op
if mode is None and isinstance(mode_instance, compile.profilemode.ProfileMode):
mode_instance = compile.profilemode.ProfileMode(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker)
compile.profilemode.prof_mode_instance_to_print.append(mode_instance)
self.mode_instance = mode_instance
self.mode_instance.message="ScanGrad sub profile"
if name is None: name = 'scan_grad_fn'
self.grad_fn = function(g_ins, g_outs, mode = mode_instance, name = name)
def __eq__(self,other): def __eq__(self,other):
rval = type(self) == type(other) rval = type(self) == type(other)
...@@ -1479,6 +1501,7 @@ class ScanGrad(Op): ...@@ -1479,6 +1501,7 @@ class ScanGrad(Op):
(self.go_backwards == other.go_backwards) and \ (self.go_backwards == other.go_backwards) and \
(self.n_outs_not_shared == other.n_outs_not_shared) and\ (self.n_outs_not_shared == other.n_outs_not_shared) and\
(self.truncate_gradient == other.truncate_gradient) and\ (self.truncate_gradient == other.truncate_gradient) and\
(self.mode == other.mode) and \
(self.seqs_taps == other.seqs_taps) and \ (self.seqs_taps == other.seqs_taps) and \
(self.outs_taps == other.outs_taps) (self.outs_taps == other.outs_taps)
return rval return rval
...@@ -1489,6 +1512,7 @@ class ScanGrad(Op): ...@@ -1489,6 +1512,7 @@ class ScanGrad(Op):
hash(self.n_outs) ^ \ hash(self.n_outs) ^ \
hash(self.go_backwards) ^\ hash(self.go_backwards) ^\
hash(self.truncate_gradient) ^\ hash(self.truncate_gradient) ^\
hash(self.mode) ^\
hash_listsDictsTuples(self.inputs) ^ \ hash_listsDictsTuples(self.inputs) ^ \
hash_listsDictsTuples(self.outputs) ^ \ hash_listsDictsTuples(self.outputs) ^ \
hash_listsDictsTuples(self.seqs_taps) ^ \ hash_listsDictsTuples(self.seqs_taps) ^ \
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论