
인공지능 모형 만들기 II

모형의 학습과 모니터링

Posted by Jong-June Jeon on June 29, 2024

예시

다음은 argparse와 wandb를 이용해서 모형을 적합(학습)하고 학습 과정을 모니터링하는 예입니다.

라이브러리 파일 (libimg.py)

# 코드
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb

class Net(nn.Module):
    """Small LeNet-style CNN that outputs 10 class logits.

    Two (conv -> ReLU -> 2x2 max-pool) stages feed three fully connected
    layers. The flatten size ``16 * 5 * 5`` is consistent with 3-channel
    32x32 inputs such as the CIFAR-10 images used by the main script
    (32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 6 -> 16 channels with 5x5 kernels; one
        # pooling module is shared by both stages.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map an image batch ``x`` to unnormalized class scores."""
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv_layer(x)))
        x = x.flatten(start_dim=1)  # keep the batch dimension, merge the rest
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

def modelFit(model, trainloader, opt, device, log_interval=20):
    """Train ``model`` with SGD (momentum 0.9) and cross-entropy loss.

    Every ``log_interval`` mini-batches the mean loss of that window is
    printed and logged to wandb under the key ``"Training loss"``.

    Args:
        model: ``nn.Module`` trained in place.
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        opt: Namespace providing ``epochs`` (int) and ``lr`` (float), e.g.
            the result of ``parse_option()``.
        device: Device each batch is moved to. The caller is responsible for
            placing ``model`` on the same device.
        log_interval: Number of mini-batches per printed/logged loss window.

    Returns:
        float: Mean training loss over all batches of the final epoch, or
        ``0.0`` when no batch was processed (``epochs == 0`` or an empty
        loader).

    Raises:
        ValueError: If ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    epochs = opt.epochs
    learning_rate = opt.lr

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    model.train()

    # Bound before the loop so that epochs == 0 returns 0.0 instead of
    # raising UnboundLocalError.
    epoch_mean = 0.0
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0  # loss sum of the current logging window
        epoch_loss = 0.0    # loss sum of the whole epoch
        n_batches = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss
            n_batches += 1
            if i % log_interval == log_interval - 1:  # end of a logging window
                window_mean = running_loss / log_interval
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {window_mean:.3f}')
                wandb.log({"Training loss": window_mean})
                running_loss = 0.0
        epoch_mean = epoch_loss / n_batches if n_batches else 0.0

    return epoch_mean
  
def parse_option(args=None):
    """Parse the training command-line options.

    Args:
        args: Optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, exactly as ``argparse`` does.

    Returns:
        argparse.Namespace with ``lr`` (float, default 0.001), ``batch_size``
        (int, default 128), ``epochs`` (int, default 10) and ``name``
        (str, default "default"; the main script uses it as the wandb run
        name).
    """
    # The text is a description; passing it positionally would set `prog`
    # and replace the script name in the usage line.
    parser = argparse.ArgumentParser(description='argument for training')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--epochs', type=int, default=10,
                        help='epochs')
    parser.add_argument('--name', type=str, default="default",
                        help='run_name')
    return parser.parse_args(args)
      

메인 실행 파일 (wandb_run2.py)

# 실행파일의 작성
import torch
import torchvision
import torchvision.transforms as transforms
import wandb
import libimg 


def main():
    """Parse options, train ``libimg.Net`` on CIFAR-10 and log the run to wandb."""
    opt = libimg.parse_option()
    print("Run Name:", opt.name)
    print("Batch Size:", opt.batch_size)
    print("Learning Rate:", opt.lr)
    print("Epoch:", opt.epochs)

    # Hyperparameters recorded in the wandb run config.
    args = {
        "learning_rate": opt.lr,
        "epochs": opt.epochs,
        "batch_size": opt.batch_size,
    }
    # Passing name/config to init replaces the separate config.update(),
    # run.name assignment and argument-less run.save() calls.
    wandb.init(project='CIFAR10 Classification', name=opt.name, config=args)

    batch_size = opt.batch_size
    # Normalize with mean = std = 0.5 per channel computes (x - 0.5) / 0.5.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    # Not used yet; kept for a later evaluation step.
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('device:', device)
    net = libimg.Net().to(device)
    # Reuse the options parsed above instead of parsing argv a second time.
    libimg.modelFit(net, trainloader, opt, device)

    # TODO: save the model weights here, e.g. torch.save(net.state_dict(), path).
    print('Finished Training')
    wandb.finish()


# DataLoader workers (num_workers=2) may re-import this module under the
# "spawn" start method (Windows/macOS default); without this guard each worker
# would rerun the whole training script.
if __name__ == "__main__":
    main()

커맨드 실행

python wandb_run2.py --name sky_blue2 --epochs 50