画像を正規化:Pythonコード 記事は雑です。
どうも、最近、○○学という雑誌にニューラルネットワークを使った論文がアクセプトされました。カミングアウトすみません。
画像を正規化:Pythonコード
意味不明と思っているかたも多いかと思います。
正規化は機械学習をするにあたり重要な処理の一つです。PyTorchを使いました。
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
class calc_mean_std:
def __init__(self) :
print("calc_img")def calc(self, train_data,batch_size,num_workers):
data_transform = transforms.Compose([transforms.ToTensor()])
train_data.dataset.transform = data_transform
dataloader = DataLoader(train_data, batch_size=batch_size,
shuffle=False, num_workers=num_workers)
mean = 0.0
for images, _ in dataloader:
batch_samples = images.size(0)
images = images.view(batch_samples,images.size(1),-1)
mean += images.mean(2).sum(0)
mean = mean/len(dataloader.dataset)var = 0.0
for images, _ in dataloader:
batch_samples = images.size(0)
images = images.view(batch_samples,images.size(1),-1)
var += *1**2).sum([0,2])
std = torch.sqrt(var/(len(dataloader.dataset)*512*512))
print(mean,std)return mean,std
大体はコピペで使用出来ます。
*1:images - mean.unsqueeze(1