【PyTorch】Dataset、Dataloaderの作成過程 詳細な理解とよくあるミス

私たちの目的はAIを利活用し業務を改善することですが、AIの勉強には高等数学の知識が必要になってきます。

画像分類の学習では、「テンソル」「正規分布(ガウス分布)」「標準化」「損失関数」「最適化関数」など非常に難しい単語が登場します。

AIの普及率が低いのはこれらの難解な数学の知識が邪魔しているのも一因かと思います。

この記事では、画像データ(CIFAR10)を用いながらDataset、Dataloaderの作成過程、特に「テンソル」「正規分布(ガウス分布)」「標準化」の言葉の意味について詳しく説明します。

CIFAR10を利用したpytorchの学習手順

CIFAR10は、10種類の「物体カラー写真」(乗り物や動物など)の画像データセットであり、画像分類を目的とした機械学習の研究やチュートリアルとして使用されています。

画像分類の手順は以下の通りです。

ここでは、「1.データを準備してDataset、Dataloaderを作成」についてみていきましょう。

Dataset、Dataloaderの作成過程コード

コードは以下の通りです。

画像分類について

そもそも画像をコンピューターで処理するということはどういうことなのか?

画像の仕組みから見ていきましょう。

写真は赤、緑、青の3つの行列の集まり

コンピューターに画像を読み込ませるには、写真というデータを数字に変換させる必要があります。

カラー写真(500*500ピクセルの場合)は赤色の500*500の行列、緑色の500*500の行列、青色の500*500の行列の3つの行列からできています。

一つ一つのピクセルは256段階(0=白~255=黒)の濃淡で表示され、コンピューターはこの数字で画像がどのような色になっているのか認識しています。

テンソル=ndarray

テンソルは行列を集めたものと考えられています。

テンソルはndarrayそのもので、ndarray形式で保存されています。

テンソルの演算において、pythonではNumpyというライブラリがあり、Numpyを使うことによって非常に効率よくコードを書くことができます。

GPUを使用する際はNumpyのndarray型ではなくPytorchのTensor型に変換する

テンソルだったり、tensorだったり非常にわかりにくいところですが、基本的にはNumpyもPytorchも同じものです。

NumpyもPytorchも画像(テンソル)を取り扱うライブラリですが、GPUを使用するときはPytorchのTensor型に変換する必要があると覚えておきましょう。

データの前処理

データの前処理の部分について詳しく見ていきます。

Pytorchでは画像データを(赤/緑/青の3色の次元, 高さ, 幅)で取り扱っており、transformオブジェクトでは以下の処理がされています。

transforms.ToTensor()でTensor型に変換

GPUを使うことを前提としているので、.ToTensor()でTensor型に変換しています。

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))で標準化

正規分布にしただけでは標準化は完了しません。

平均が0、標準偏差が1の標準正規分布に変換することで標準化が完了することに注意します。

サンプル画像を以下のコードで取得して正しくTensor型になっているのか確認してみましょう。

Tensor型になったのか確認

標準化ができたか確認

標準化が正しくできているかどうか以下のコードで確認します。

よくあるエラー

実際に自分で動かしてみて発生したエラーの対処法を紹介します。

RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0

上記のように打ち間違えると、次のエラーが発生します。

前述した標準化の部分のエラーです。

まとめ

Dataset、Dataloaderの作成過程では「テンソル」「正規分布(ガウス分布)」「標準化」などの難しい言葉が登場します。

全てを詳細に理解する必要はありませんが、知見として理解をしておくと便利です。