PyTorch 之 Checkpoint 機(jī)制解析
以下文章來(lái)源于極市平臺(tái) ,作者CV開(kāi)發(fā)者都愛(ài)看的
作者丨Lart
編輯丨極市平臺(tái)
導(dǎo)讀
PyTorch 提供了一種非常方便的節(jié)省顯存的方式,就是 Checkpoint 機(jī)制。這篇文章的目的在于更透徹的了解其內(nèi)在的機(jī)制。
Checkpoint 機(jī)制
該技術(shù)的核心是一種使用時(shí)間換空間的策略。在現(xiàn)有的許多方法中被大量使用,例如 DenseNet、Swin Transformer 源碼中都可以看到它的身影。
為了了解它的工作原理,我們先得弄明白的一個(gè)問(wèn)題是,PyTorch 模型在訓(xùn)練過(guò)程中顯存占用主要是用來(lái)存儲(chǔ)什么?
關(guān)于這一點(diǎn),Connolly 的文章 《PyTorch 顯存機(jī)制分析》(https://zhuanlan.zhihu.com/p/424512257) 介紹的非常詳細(xì):
開(kāi)門(mén)見(jiàn)山的說(shuō),PyTorch 在進(jìn)行深度學(xué)習(xí)訓(xùn)練的時(shí)候,有 4 大部分的顯存開(kāi)銷,分別是模型參數(shù)(parameters),模型參數(shù)的梯度(gradients),優(yōu)化器狀態(tài)(optimizer states) 以及 中間激活值(intermediate activations) 或者叫中間結(jié)果(intermediate results)。
而通過(guò) Checkpoint 技術(shù),我們可以通過(guò)一種取巧的方式,使用 PyTorch 提供的 “no-grad” (no_grad())模式來(lái)避免將這部分運(yùn)算被autograd記錄到反向圖“backward graph”中,從而避免了對(duì)于中間激活值的存儲(chǔ)需求。
個(gè)人理解(歡迎指出錯(cuò)誤):
前向傳播時(shí) autograd 記錄各個(gè)操作反向傳播需要的一些信息和中間變量。反向傳播之后,用于計(jì)算梯度的中間結(jié)果會(huì)被釋放。也就是說(shuō),模型參數(shù)、優(yōu)化器狀態(tài)和參數(shù)梯度是始終在占用著存儲(chǔ)空間的,中間激活值在反向傳播之后就自動(dòng)被清空了。這里我簡(jiǎn)單修改了 《PyTorch 顯存機(jī)制分析》(https://zhuanlan.zhihu.com/p/424512257) 中給出的例子 進(jìn)行了一下驗(yàn)證(https://github.com/lartpang/CodeForArticle/tree/main/CheckpointAndGPUUsage.PyTorch) 。
這里實(shí)際上會(huì)引申出另一個(gè)問(wèn)題,為什么自定義 Function 一般情況下會(huì)減少顯存占用?(在 Vision Longformer 中各種實(shí)現(xiàn)的對(duì)比里可以明顯看到這一現(xiàn)象)
我覺(jué)得主要是因?yàn)樽远x Function 的時(shí)候,我們可以從一整個(gè)模塊的角度來(lái)更有針對(duì)性的在 ctx 中存儲(chǔ)中間變量,而自動(dòng)求導(dǎo)引擎可能關(guān)注的太細(xì)了,導(dǎo)致存儲(chǔ)許多不必要的中間變量。關(guān)于這一點(diǎn)暫時(shí)不知道如何驗(yàn)證。
這可以避免存儲(chǔ)模型特定層中間運(yùn)算結(jié)果,從而有效降低了前向傳播中顯存的占用。 這些中間結(jié)果會(huì)在反向傳播的時(shí)候被即時(shí)重新計(jì)算一次。要注意,被 checkpoint 包裹的層反向傳播時(shí)仍然會(huì)在第一次反向傳播的時(shí)候開(kāi)辟存儲(chǔ)梯度的空間。
因?yàn)?checkpoint 是在 torch.no_grad() 模式下計(jì)算的目標(biāo)操作的前向函數(shù),這并不會(huì)修改原本的葉子結(jié)點(diǎn)的狀態(tài),有梯度的還會(huì)保持。只是關(guān)聯(lián)這些葉子結(jié)點(diǎn)的臨時(shí)生成的中間變量會(huì)被設(shè)置為不需要梯度,因此梯度鏈?zhǔn)疥P(guān)系會(huì)被斷開(kāi)。
通過(guò)這樣的方式,雖然延長(zhǎng)了反向傳播的時(shí)間,但是卻也在一定程度上緩解了存儲(chǔ)大量中間變量帶來(lái)的顯存占用。
源碼解析
以下代碼來(lái)自 PyTorch v1.10.1 版本:https://github.com/pytorch/pytorch/blob/v1.10.1/torch/utils/checkpoint.py。最新的版本中補(bǔ)充了一些新的內(nèi)容,待其最終發(fā)布后再說(shuō)吧,下面的內(nèi)容本身已經(jīng)將 checkpoint 的核心介紹了。
輔助函數(shù)
這部分代碼中首先構(gòu)造了數(shù)個(gè)輔助函數(shù),主要是用來(lái)做一些針對(duì)輸入的檢查和處理,同時(shí)也要處理好隨機(jī)種子的問(wèn)題。
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: if isinstance(inputs, tuple): out = [] for inp in inputs: if not isinstance(inp, torch.Tensor): out.append(inp) continue # 直接detach(),從inp所在的計(jì)算圖中剝離,默認(rèn)會(huì)自動(dòng)將requires_grad置為False x = inp.detach() # 但是這里的實(shí)際需求中,仍需要保持其自身的需要記錄梯度的屬性,且其梯度變?yōu)镹one x.requires_grad = inp.requires_grad # 因?yàn)橹挥行枰4嫣荻鹊膮?shù)才能夠構(gòu)建梯度的傳播路徑 out.append(x) return tuple(out) else: raise RuntimeError( "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__) def check_backward_validity(inputs: Iterable[Any]) -> None: """檢查輸入?yún)?shù)是否至少有一個(gè)需要記錄梯度的Tensor,這樣才能確保輸出也有梯度。""" if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
由于需要重復(fù)計(jì)算,所以隨機(jī)狀態(tài)的一致性是需要重視的。由于前向傳播的部分在反向過(guò)程中仍會(huì)計(jì)算一次,所以如果不使用原始的隨機(jī)狀態(tài)的話,會(huì)導(dǎo)致重新計(jì)算和原本正常計(jì)算過(guò)程中的隨機(jī)狀態(tài)不同,而影響模型的行為。
另外在這段代碼的注釋中提到了一點(diǎn)有趣的地方:
由于無(wú)法獲悉被 checkpoint 處理的操作是否在運(yùn)算中間會(huì)將一些參數(shù)移動(dòng)到不同的設(shè)備上,這可能需要手動(dòng)保存這些設(shè)備對(duì)應(yīng)的隨機(jī)狀態(tài)。當(dāng)前的實(shí)現(xiàn)直接保存了所有可見(jiàn)設(shè)備上的隨機(jī)狀態(tài),但是這樣有時(shí)可能是不必要的,但是目前尚沒(méi)有較好的解決策略。
所以按照文檔的意思,就是在說(shuō)如果沒(méi)有這樣的移動(dòng),那就可以不用保存隨機(jī)狀態(tài)咯?這一點(diǎn)其實(shí)有些令人疑惑。
# We can't know if the run_fn will internally move some args to different devices, # which would require logic to preserve rng states for those devices as well. # We could paranoically stash and restore ALL the rng states for all visible devices, # but that seems very wasteful for most cases. Compromise: Stash the RNG state for # the device of all Tensor args. # # To consider: maybe get_device_states and set_device_states should reside in torch/random.py? def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: """獲取不同輸入對(duì)應(yīng)的GPU設(shè)備的隨機(jī)數(shù)生成器的狀態(tài)""" # This will not error out if "arg" is a CPU tensor or a non-tensor type because # the conditionals short-circuit. fwd_gpu_devices = list(set(arg.get_device() for arg in args if isinstance(arg, torch.Tensor) and arg.is_cuda)) fwd_gpu_states = [] for device in fwd_gpu_devices: with torch.cuda.device(device): fwd_gpu_states.append(torch.cuda.get_rng_state()) return fwd_gpu_devices, fwd_gpu_states def set_device_states(devices, states) -> None: """針對(duì)不同的設(shè)備設(shè)置隨機(jī)數(shù)生成器的狀態(tài)""" for device, state in zip(devices, states): with torch.cuda.device(device): torch.cuda.set_rng_state(state)
核心 Function
可以看到,這里的 Checkpoint 本身就是基于 PyTorch 的 Function 實(shí)現(xiàn)的一個(gè)擴(kuò)展算子,所以該部分代碼也涉及到了 Function 的諸多功能。閱讀它既可以幫助我們同時(shí)復(fù)習(xí)一下相關(guān)的知識(shí),又能進(jìn)一步了解更復(fù)雜的處理邏輯該如何搭建。
class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) # 暫存前向傳播函數(shù) ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state # 用來(lái)保存當(dāng)前模型的混合精度的狀態(tài),以用在反向傳播中 ctx.had_autocast_in_fwd = torch.is_autocast_enabled() if preserve_rng_state: # 保存目標(biāo)模塊前向傳播之前,此時(shí)CPU和GPU的隨機(jī)數(shù)生成器的狀態(tài) ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. # (If the user intends that the context is initialized later, within their # run_function, we SHOULD actually stash the cuda state here. Unfortunately, # we have no way to anticipate this will happen before we run the function.) ctx.had_cuda_in_fwd = False if torch.cuda._initialized: # PyTorch提供的一個(gè)內(nèi)部變量,用于判定CUDA狀態(tài)是否已經(jīng)被初始化了 # torch.cuda.is_initialized中就用到了該變量 ctx.had_cuda_in_fwd = True # 保存輸入變量涉及的各個(gè)GPU設(shè)備的隨機(jī)狀態(tài) ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) # Save non-tensor inputs in ctx, keep a placeholder None for tensors # to be filled out during the backward. ctx.inputs = [] ctx.tensor_indices = [] tensor_inputs = [] for i, arg in enumerate(args): if torch.is_tensor(arg): tensor_inputs.append(arg) ctx.tensor_indices.append(i) ctx.inputs.append(None) else: ctx.inputs.append(arg) # save_for_backward()中保存反向傳播中需要用到的輸入和輸出tensor量。 # 由于在反向傳播中需要重新計(jì)算記錄梯度的output,所以就不要保存output了。 # 并且后面的計(jì)算也不需要在梯度模式下計(jì)算。 ctx.save_for_backward(*tensor_inputs) with torch.no_grad(): # 不保存梯度的前向傳播操作,也就是說(shuō)這里的output是不會(huì)記錄中間變量,無(wú)法直接計(jì)算梯度的。 outputs = run_function(*args) return outputs @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): raise RuntimeError( "Checkpointing is not compatible with .grad() or when an `inputs` parameter" " is passed to .backward(). Please use .backward() and do not pass its `inputs`" " argument.") # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors # 獲取前向傳播中保存的輸入tensor # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. rng_devices = [] if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: rng_devices = ctx.fwd_gpu_devices # 使用之前前向傳播開(kāi)始之前保存的隨機(jī)數(shù)生成器的狀態(tài)來(lái)進(jìn)行一次一模一樣的前向傳播過(guò)程 with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): # 使用上下文管理器保護(hù)原始的隨機(jī)數(shù)生成器的狀態(tài),內(nèi)部處理后在進(jìn)行復(fù)原 if ctx.preserve_rng_state: torch.set_rng_state(ctx.fwd_cpu_state) if ctx.had_cuda_in_fwd: set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) # 這里將inputs從計(jì)算圖中剝離開(kāi),但是其屬性requires_grad和原來(lái)是一樣的,這么做的目的是為了截?cái)喾聪騻鞑サ穆窂健? # 從整個(gè)操作目的來(lái)看,由于我們需要重新計(jì)算輸出,并將梯度回傳到輸入上,所以輸入本身需要可以記錄梯度。 # 但是這里的回傳不可以影響到checkpoint之外更靠前的那些操作, # backward之后會(huì)將之前保存的中間變量釋放掉,而我們僅僅是為了計(jì)算當(dāng)前一小塊結(jié)構(gòu),所以梯度回傳需要截?cái)唷? detached_inputs = detach_variable(tuple(inputs)) # 會(huì)變成葉子結(jié)點(diǎn),grad和grad_fn均重置為None # 處理完隨機(jī)狀態(tài)之后,就該準(zhǔn)備著手重新前向傳播了。 # 這次前向傳播是在梯度模式(torch.enable_grad())下執(zhí)行的。此時(shí)會(huì)保存中間變量。 with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd): outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): outputs = (outputs,) # run backward() with only tensor that requires grad outputs_with_grad = [] args_with_grad = [] for i in range(len(outputs)): # 記錄需要計(jì)算梯度的輸出outputs[i]以及對(duì)應(yīng)的回傳回來(lái)的有效梯度args[i] if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: outputs_with_grad.append(outputs[i]) args_with_grad.append(args[i]) # 檢查需要計(jì)算梯度的輸出,如果沒(méi)有輸出需要計(jì)算梯度,那么實(shí)際上就說(shuō)明這個(gè)模塊是不參與梯度計(jì)算的, # 也就是說(shuō),該模塊不需要使用checkpoint來(lái)調(diào)整。 if len(outputs_with_grad) == 0: raise RuntimeError( "none of output has requires_grad=True," " this checkpoint() is not necessary") # 該操作對(duì)被包裹的目標(biāo)操作計(jì)算反向傳播,即計(jì)算回傳到輸入detached_inputs上的梯度。 # 由于輸入的tensor已被從整體梯度圖中剝離,所以可以看做是一個(gè)葉子結(jié)點(diǎn),可以在反向傳播之后獲得其梯度,并且中間變量也會(huì)隨之釋放。 # 另外這里反傳計(jì)算梯度也不會(huì)導(dǎo)致將更靠前的結(jié)構(gòu)中暫時(shí)保存來(lái)計(jì)算梯度的參數(shù)給釋放掉。 torch.autograd.backward(outputs_with_grad, args_with_grad) # 如果前面不執(zhí)行detach(),這里的inp.grad會(huì)被直接釋放并置為None,這并不符合預(yù)期 grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) # 這里返回的梯度與當(dāng)前類的forward輸入一一對(duì)應(yīng), # 由于這里的forward包含著本不需要梯度的兩個(gè)參數(shù)run_function、preserve_rng_state,故對(duì)應(yīng)回傳None即可。 return (None, None) + grads
這里實(shí)際上就是在原始的操作和整體的計(jì)算圖之間添加了一個(gè)中間層,用于信息的交互:
原始模型的數(shù)據(jù)傳輸?shù)奖话哪繕?biāo)層的時(shí)候,數(shù)據(jù)進(jìn)入 checkpoint 的 forward() 中,被 checkpoint 進(jìn)行檢查和記錄后,再送入目標(biāo)層中;
目標(biāo)層在非梯度模式下執(zhí)行前向傳播。該模式下,新創(chuàng)建的 tensor 都是不會(huì)記錄梯度信息的;
目標(biāo)層的結(jié)果通過(guò) checkpoint 的前向傳播輸出,送入模型后續(xù)的其他結(jié)構(gòu)中;
執(zhí)行反向傳播,損失求導(dǎo),鏈?zhǔn)交貍?,?jì)算梯度;
回傳回來(lái)的對(duì)應(yīng)于 checkpoint 輸出的梯度被送入其對(duì)應(yīng)的反向傳播函數(shù),即 checkpoint 的 backward()。
梯度送入 checkpoint 中后,需要進(jìn)一步將梯度回傳到目標(biāo)層的輸入上。由于在 checkpoint 的 forward 中目標(biāo)層本身前向傳播是處于非梯度狀態(tài)下,所以回傳路徑上缺少目標(biāo)層中操作的梯度子圖。于是為了獲取這部分信息,需要先梯度狀態(tài)下對(duì)目標(biāo)層進(jìn)行一次前向傳播,通過(guò)將回傳回來(lái)的梯度和目標(biāo)層的輸出一起執(zhí)行 torch.autograd.backward(outputs_with_grad, args_with_grad),從而獲得對(duì)應(yīng)輸入的梯度信息。
將對(duì)應(yīng)目標(biāo)操作輸入的梯度信息按照 checkpoint 本身 Function 的 backward 的需求,使用 None 對(duì)其他輔助參數(shù)的梯度占位后進(jìn)行返回。
返回的對(duì)應(yīng)于其他模塊的輸出量的梯度,被沿著反向傳播的路徑送入對(duì)應(yīng)操作的 backward 中,一層一層回傳累加到各個(gè)葉子節(jié)點(diǎn)上。
定義好操作后,進(jìn)行一個(gè)簡(jiǎn)單的包裝,同時(shí)處理一下默認(rèn)參數(shù),補(bǔ)充了更細(xì)致的文檔:
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs): r"""Checkpoint a model or part of the model Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does **not** save intermediate activations, and instead recomputes them in backward pass. It can be applied on any part of a model. Specifically, in the forward pass, :attr:`function` will run in :func:`torch.no_grad` manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the :attr:`function` parameter. In the backwards pass, the saved inputs and :attr:`function` is retrieved, and the forward pass is computed on :attr:`function` again, now tracking the intermediate activations, and then the gradients are calculated using these activation values.
這一段詳細(xì)介紹了checkpoint的核心技術(shù),也就是在非梯度模式下執(zhí)行目標(biāo)操作的前向傳播,只保留輸入和結(jié)構(gòu)參數(shù),省去了中間激活的保存。反向傳播時(shí)在梯度模式下重新計(jì)算這些激活,重建這部分反向圖,進(jìn)而實(shí)現(xiàn)了梯度的正?;貍鳌? The output of :attr:`function` can contain non-Tensor values and gradient recording is only performed for the Tensor values. Note that if the output consists of nested structures (ex: custom objects, lists, dicts etc.) consisting of Tensors, these Tensors nested in custom structures will not be considered as part of autograd. 因?yàn)閏heckpoint的backward實(shí)現(xiàn)的邏輯中,直接遍歷目標(biāo)操作的輸出(會(huì)被自定轉(zhuǎn)換成元組類型)并確定那些需要回流梯度的輸出。如果輸出中包含其他的非tensor結(jié)構(gòu),就會(huì)導(dǎo)致在遍歷過(guò)程中這些輸出被忽略掉。不過(guò)也確實(shí),這樣直接簡(jiǎn)化處理雖然使得靈活性下降,但是卻也避免了代碼過(guò)于復(fù)雜。 .. warning:: Checkpointing currently only supports :func:`torch.autograd.backward` and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` is not supported. .. warning:: If :attr:`function` invocation during backward does anything different than the one during forward, e.g., due to some global variable, the checkpointed version won't be equivalent, and unfortunately it can't be detected. 盡量保證目標(biāo)操作在反向計(jì)算期間和前向期間的操作的一致性。 因?yàn)樵赾heckpoint會(huì)在反向中重新計(jì)算一次前向,這可能會(huì)帶來(lái)一些由于無(wú)法檢測(cè)到的不確定因素而造成的與常規(guī)版本的差異。 .. warning:: If checkpointed segment contains tensors detached from the computational graph by `detach()` or `torch.no_grad()`, the backward pass will raise an error. This is because `checkpoint` makes all the outputs require gradients which causes issues when a tensor is defined to have no gradient in the model. To circumvent this, detach the tensors outside of the `checkpoint` function. 不要在目標(biāo)操作中包含detach或者非梯度模式的處理。 **在我的實(shí)際測(cè)試中似乎并沒(méi)有這個(gè)問(wèn)題?**或許這里應(yīng)該看一下pytorch提供的測(cè)試案例。 .. warning:: At least one of the inputs needs to have :code:`requires_grad=True` if grads are needed for model inputs, otherwise the checkpointed part of the model won't have gradients. At least one of the outputs needs to have :code:`requires_grad=True` as well. 要保證至少有一個(gè)輸入是requires_grad的,這樣才可以保證這部分操作可以被記錄梯度。 也要保證輸出至少有一個(gè)需要計(jì)算梯度。 Args: function: describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in LSTM, if user passes ``(activation, hidden)``, :attr:`function` should correctly use the first input as ``activation`` and the second input as ``hidden`` preserve_rng_state(bool, optional, default=True): Omit stashing and restoring the RNG state during each checkpoint. args: tuple containing inputs to the :attr:`function` Returns: Output of running :attr:`function` on :attr:`*args` """ # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop('preserve_rng_state', True) if kwargs: raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) return CheckpointFunction.apply(function, preserve, *args)
應(yīng)用案例
Checkpoint for Sequential
PyTorch 源碼中給了一個(gè)很直接的應(yīng)用案例,就是將 checkpoint 應(yīng)用于 Sequential 搭建起來(lái)的模型。按照分段數(shù) segments 指定的,將模型劃分為多段。
def checkpoint_sequential(functions, segments, input, **kwargs): r"""A helper function for checkpointing sequential models. Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model in various segments and checkpoint each segment. All segments except the last will run in :func:`torch.no_grad` manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass. See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. .. warning:: Checkpointing currently only supports :func:`torch.autograd.backward` and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` is not supported. .. warning: At least one of the inputs needs to have :code:`requires_grad=True` if grads are needed for model inputs, otherwise the checkpointed part of the model won't have gradients. .. warning: Since PyTorch 1.4, it allows only one Tensor as the input and intermediate outputs, just like :class:`torch.nn.Sequential`. Args: functions: A :class:`torch.nn.Sequential` or the list of modules or functions (comprising the model) to run sequentially. segments: Number of chunks to create in the model input: A Tensor that is input to :attr:`functions` preserve_rng_state(bool, optional, default=True): Omit stashing and restoring the RNG state during each checkpoint. Returns: Output of running :attr:`functions` sequentially on :attr:`*inputs` Example: >>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var) """ # Hack for keyword-only parameter in a python 2.7-compliant way preserve = kwargs.pop('preserve_rng_state', True) if kwargs: raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) def run_function(start, end, functions): def forward(input): for j in range(start, end + 1): input = functions[j](input) return input return forward if isinstance(functions, torch.nn.Sequential): functions = list(functions.children()) # 獲取Sequential的子模塊,這里使用children方法,僅獲取最外層 segment_size = len(functions) // segments # the last chunk has to be non-volatile (為什么?似乎加上也是可以的) end = -1 for start in range(0, segment_size * (segments - 1), segment_size): end = start + segment_size - 1 # 迭代式的將各個(gè)子模塊集合使用checkpoint包裝并前向傳播。 input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve) # 剩余的結(jié)構(gòu)不再使用checkpoint return run_function(end + 1, len(functions) - 1, functions)(input)
參考鏈接
Checkpoint 源碼:https://github.com/pytorch/pytorch/blob/master/torch/utils/checkpoint.py
PyTorch 的 Autograd - xiaopl 的文章 - 知乎 https://zhuanlan.zhihu.com/p/69294347
PyTorch 源碼解讀之 torch.autograd:梯度計(jì)算詳解 - OpenMMLab 的文章 - 知乎 https://zhuanlan.zhihu.com/p/321449610
淺談 PyTorch 中的 tensor 及使用 - xiaopl 的文章 - 知乎 https://zhuanlan.zhihu.com/p/67184419
https://pytorch.org/docs/stable/notes/autograd.html#locally-disable-grad-doc
https://pytorch.org/tutorials/beginner/introyt/autogradyt_tutorial.html
本文僅做學(xué)術(shù)分享,如有侵權(quán),請(qǐng)聯(lián)系刪文。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。
pwm相關(guān)文章:pwm原理