代表通信~Pytorch再びとStackGAN③
こんばんは。代表の草場です。
ラボ活は土日も好調に行われていました。ボードゲームアプリ化部隊が活発にやり取りしています。そしてTECH女子が「How to be a データサイエンティスト_経験ゼロからキャリアチェンジを考える」というイベントを開催。株式会社Rejouiとの共催です。素晴らしいイベントです。次回も楽しみにしています。
StackGANいじっていますが、Pytorchのデータの扱いとかわかってないな、ということで再度勉強。「pyTorchのtransforms,Datasets,Dataloaderの説明と自作Datasetの作成と使用」がとても分かりやすかったので、写経しながら勉強しました。簡単に振り返ります。
データの前処理はtorchvisionのtransformsで行うわけですが、変換は、
trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5,), (0.5, ))])
基本これだけ。torchvision.transforms.Composeは引数で渡されたlist型の[,,…]というのを先頭から順に実行します。ToTensor()でテンソルに、Normalize()で正規化。例えば2を足すtransformsを自作すると以下の感じ。
class Plus2(object):
def __init__(self):
pass
def __call__(self, x):
data = x + 2
return data
trans = Plus2()
x = 9
data = trans(x)
print(x)
print(data)
class Mydatasets(torch.utils.data.Dataset):
def __init__(self, transform = None):
self.transform = transform
self.data = [1, 2, 3, 4, 5, 6]
self.label = [0, 1, 0, 1, 0, 1]
self.datanum = 6
def __len__(self):
return self.datanum
def __getitem__(self, idx):
out_data = self.data[idx]
out_label = self.label[idx]
if self.transform:
out_data = self.transform(out_data)
return out_data, out_label
t= Plus2()
dataset = Mydatasets(t)
print(len(dataset))
print(dataset[5])
print(dataset)
大事なのは、classの「def len()」は「len(dataset)」で実行されるdataの長さを返す関数で、「def getitem()」は「dataset[5]」とするとその番号のdataとlabelを返す関数。このようにdatasetsは「def len()」と「def getitem()」が必須、とのこと。
だいぶんわかってきた。
さて、懲りずに馬の生成をStackGANで試みる。
two horses with red feathers on top of their heads(頭の上に赤い羽根のついた二頭の馬)
待て待て待て、馬のクセがすごい!(千鳥調)
明日も大グセじゃ。
EVENTS