BLOG

シンラボメンバーのあれこれ

  1. HOME
  2. ブログ
  3. 代表通信
  4. 代表通信~Pytorch再びとStackGAN③

代表通信~Pytorch再びとStackGAN③

こんばんは。代表の草場です。

ラボ活は土日も好調に行われていました。ボードゲームアプリ化部隊が活発にやり取りしています。そしてTECH女子が「How to be a データサイエンティスト_経験ゼロからキャリアチェンジを考える」というイベントを開催。株式会社Rejouiとの共催です。素晴らしいイベントです。次回も楽しみにしています。

StackGANいじっていますが、Pytorchのデータの扱いとかわかってないな、ということで再度勉強。「pyTorchのtransforms,Datasets,Dataloaderの説明と自作Datasetの作成と使用」がとても分かりやすかったので、写経しながら勉強しました。簡単に振り返ります。
データの前処理はtorchvisionのtransformsで行うわけですが、変換は、

trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5,), (0.5, ))])

基本これだけ。torchvision.transforms.Composeは引数で渡されたlist型の[,,…]というのを先頭から順に実行します。ToTensor()でテンソルに、Normalize()で正規化。例えば2を足すtransformsを自作すると以下の感じ。

class Plus2(object):
  def __init__(self):
    pass
  def __call__(self, x):
    data = x + 2
    return data
trans = Plus2()
x = 9
data = trans(x)
print(x)
print(data)
output
9
11
transformsは「call()」という関数が重要で、この中に書いた処理が実行されます。
また、1から6の数字のdataと、奇数が0偶数が1となるlabelを持つDatasetを自作すると以下の感じです。
class Mydatasets(torch.utils.data.Dataset):
  def __init__(self, transform = None):
    self.transform = transform
    self.data = [1, 2, 3, 4, 5, 6]
    self.label = [0, 1, 0, 1, 0, 1]
    self.datanum = 6
  def __len__(self):
    return self.datanum
  def __getitem__(self, idx):
    out_data = self.data[idx]
    out_label = self.label[idx]
    if self.transform:
      out_data = self.transform(out_data)
    return out_data, out_label
実際に使うには以下。
t= Plus2()
dataset = Mydatasets(t)
print(len(dataset))
print(dataset[5])
print(dataset)
output
6
(8, 1)
<__main__.Mydatasets object at 0x7fb72d4d0c50>

大事なのは、classの「def len()」は「len(dataset)」で実行されるdataの長さを返す関数で、「def getitem()」は「dataset[5]」とするとその番号のdataとlabelを返す関数。このようにdatasetsは「def len()」と「def getitem()」が必須、とのこと。

だいぶんわかってきた。
さて、懲りずに馬の生成をStackGANで試みる。
two horses with red feathers on top of their heads(頭の上に赤い羽根のついた二頭の馬)

待て待て待て、馬のクセがすごい!(千鳥調)

明日も大グセじゃ。

関連記事