[Torchvision] DenseNetのPretrained WeightをLocalから読み込もうとしたら面倒だった話

MACHINE LEARNING

こんにちは

モデルを重みを使わずに初期化して、LocalからPretrained weightを読み込む必要があったのですが、少し面倒だったので解決までの手順を書き残そうと思います。

他に試したVisionTransformerやVGGでは同様の事態は発生していません。DenseNetでのみ発生しています。また、間違えている点があれば、コメントいただけると嬉しいです。

使用しているライブラリのバージョンは以下のとおりです。

torch==1.12.1
torchvision==0.13.1

モチベーション

研究で使っているGPUクラスタの容量の関係(?)でDenseNet161のPretrained Weightをダウンロード(引数でweights=DenseNet161_Weights)できなかったので、Localにあらかじめダウンロードして置けば読み込めるのでは、と思い試してみました。

発生したエラー

以下が入力と出力です。
モデルはtorchvisionのmodelsから引っ張ってきていて、Pretrained Weightも対応するもの(documentのsourceに書いてある、https://download.pytorch.org/models/densenet161-8d451a50.pth)を読み込んでいます。
普通に考えて問題なく読み込めそうですが、ダメでした。

import torch
from torchvision import models
model = models.densenet161(weights=None)
model.load_state_dict(torch.load('src/models/weight/densenet161-8d451a50.pth'))

>> RuntimeError: Error(s) in loading state_dict for DenseNet:
	Missing key(s) in state_dict: "features.denseblock1.denselayer1.norm1.weight", ..., "features.denseblock1.denselayer1.conv1.weight", ...
	Unexpected key(s) in state_dict: "features.denseblock1.denselayer1.norm.1.weight", ... ,"features.denseblock1.denselayer1.conv.1.weight", ...

定義したモデルのKeyと、Pretrained WeightのKeyがマッチしていないので読み込めないようです。

エラーを見るとどうやらモデルはfeatures.denseblock1.denselayer1.norm1.weightなのに、Pretrained Weightはfeatures.denseblock1.denselayer1.norm.1.weightとなっているようです。normとかconvの後に「.」が付くか付かないかの違いがあり、そのミスマッチがモデルのあらゆる層で発生しているということですね。

もちろんKeyが一致していないと重みは読み込めないので、なんとか統一する必要があります。

解決方法

Pretrained WeightのKeyを書き換える。

それでは、具体的にどのように書き換えるのかをみてみましょう。

sourceを見ていると、事前学習済みの重みに関係する以下のような関数と記述がありました。

»GitHubのDenseNet実装を見る

def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )
    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)

昔のDenseLayerには「norm.1」とか「conv.1」みたいな「.」が付いたものがあるけど、それいまは使わないようにしているよ。とコメントには書いてあり、関数自体は昔のKeyを置き換えるもののようです。

引数として事前学習済みの重みの使用を指示すれば、この関数を経由して問題なくPretrained Modelが使えるわけですが、Localから読み込んだりする場合はもちろん使えません。昔のKeyのままなので。

そこで、私は以下のように書き、この問題を解決しました。

import torch
from torchvision import models

def _load_state_dict(model, weights):

    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )
    state_dict = torch.load(weights)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)

model = models.densenet161(weights=None)
_load_state_dict(model = model,
                 weights = 'src/models/weight/densenet161-8d451a50.pth'))

公式の関数を少し簡略化して使ってみました。これで問題なくPretrained Weightが読み込めるはずです(多分)。

最後までお読みいただきありがとうございました。




コメント

タイトルとURLをコピーしました