# 노트북에서 Training Operator를 활용하여 병렬 학습 모델 구현하기

In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from kakaocloud_kbm.training import TrainingClient
from kakaocloud_kbm.training.utils.utils import get_default_target_namespace



## Fashion MNIST 분류 모델(MLP) 학습 함수 선언

- Training Job에 넣어줄 간단한 완전연결 신경망(MLP) 모델 학습 함수를 선언합니다.
- torchvision 패키지를 통해 Fashion MNIST 데이터를 다운로드하는 코드가 포함되어 있습니다.

In [2]:
def train_pytorch_model():
    """Train a small fully connected Fashion MNIST classifier with PyTorch DDP.

    Every import and helper lives inside the function body so that the
    function can be handed over as a single unit (it is passed as ``func``
    when the PyTorchJob is created).  When the body runs as ``__main__`` it
    spawns one training process, which joins the NCCL process group, trains
    for ``MAX_EPOCHS`` epochs and writes ``ckpt.pt`` every ``SAVE_EVERY``
    epochs.

    Returns:
        None.  Progress and the total training duration are reported through
        ``logging``.
    """
    import logging
    import time
    import datetime
    import os
    import torch
    import torch.multiprocessing as mp
    from torch import nn
    from torch.distributed import init_process_group, destroy_process_group
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.distributed import DistributedSampler
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%SZ",
        level=logging.DEBUG,
    )

    MAX_EPOCHS = 100
    SAVE_EVERY = 1
    BATCH_SIZE = 32
    # NOTE(review): no device index is given, so the current CUDA device is
    # used — this presumably relies on one GPU per process; confirm.
    device = torch.device("cuda")

    class Trainer:
        """Runs the DDP training loop for one model / data loader / optimizer."""

        def __init__(
            self,
            model: torch.nn.Module,
            train_data: DataLoader,
            optimizer: torch.optim.Optimizer,
        ) -> None:
            # Move the model to the GPU first, then wrap it so gradients are
            # synchronized across the process group.
            self.model = DDP(model.to(device))
            self.train_data = train_data
            self.optimizer = optimizer
            # Built once here instead of once per batch.
            self.loss_fn = torch.nn.CrossEntropyLoss()
            self.start_time = None
            self.end_time = None
            self.duration = None

        def _run_epoch(self, epoch):
            """Run one full pass over ``self.train_data``."""
            # DistributedSampler derives its shuffling seed from the epoch
            # number.  Without this call every epoch replays exactly the same
            # sample order.  It must happen before any iterator is created.
            sampler = getattr(self.train_data, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)

            # Peek at the first batch only to log its size.
            b_sz = len(next(iter(self.train_data))[0])
            logging.info(f"b_sz : {b_sz} / epoch : {epoch} / len data : {len(self.train_data)}")
            for source, targets in self.train_data:
                source = source.to(device)
                targets = targets.to(device)

                self.optimizer.zero_grad()
                output = self.model(source)
                loss = self.loss_fn(output, targets)
                loss.backward()
                self.optimizer.step()

        def _save_checkpoint(self, epoch):
            """Save the unwrapped model's state dict to ``ckpt.pt``."""
            # ``.module`` is the original model inside the DDP wrapper.
            # NOTE(review): every process writes its own ``ckpt.pt``; confirm
            # whether only rank 0 should save.
            ckp = self.model.module.state_dict()
            torch.save(ckp, "ckpt.pt")
            logging.info(f"Epoch {epoch} | Training ckpt saved at ckpt.pt")

        def train(self):
            """Train for ``MAX_EPOCHS`` epochs and log the elapsed time."""
            self.start_time = time.time()
            for epoch in range(MAX_EPOCHS):
                self._run_epoch(epoch)
                if epoch % SAVE_EVERY == 0:
                    self._save_checkpoint(epoch)

            self.end_time = time.time()
            sec = self.end_time - self.start_time
            self.duration = str(datetime.timedelta(seconds=sec))
            logging.info(f"{self.duration} sec")

    class NeuralNetwork(nn.Module):
        """Three-layer fully connected classifier: 28*28 inputs, 10 logits."""

        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)
            )

        def forward(self, x):
            """Flatten ``x`` and return the unnormalized class logits."""
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits

    def load_train_dataset_model_and_opt():
        """Return the Fashion MNIST training set, a fresh model and its SGD optimizer."""
        # ``download=True`` fetches the data into ./data when it is missing.
        train_set = datasets.FashionMNIST(
            root="data",
            train=True,
            download=True,
            transform=ToTensor(),
        )

        model = NeuralNetwork()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        return train_set, model, optimizer

    def prepare_dataloader(dataset: Dataset, batch_size: int):
        """Wrap ``dataset`` in a DataLoader driven by a DistributedSampler."""
        # ``shuffle`` must stay False on the DataLoader because a sampler is
        # supplied; the DistributedSampler does the (per-epoch) shuffling.
        return DataLoader(
            dataset, batch_size=batch_size, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset)
        )

    # ``main`` is published as a module-level name, presumably so the spawned
    # child process can resolve it by name — confirm before changing.
    global main

    def main(rank: int):
        """Entry point of the spawned process; ``rank`` is the index passed by ``mp.spawn``."""
        # No init_method is given, so rendezvous settings come from the
        # environment (WORLD_SIZE is read from there below as well).
        init_process_group(backend="nccl")
        try:
            dataset, model, optimizer = load_train_dataset_model_and_opt()
            logging.info(f"WORLD_SIZE : {os.environ['WORLD_SIZE']} / len(dataset) : {len(dataset)}")
            train_data = prepare_dataloader(dataset, batch_size=BATCH_SIZE)

            trainer = Trainer(model, train_data, optimizer)
            trainer.train()
        finally:
            # Tear the process group down even if training raises.
            destroy_process_group()

    if __name__ == "__main__":
        torch.cuda.empty_cache()
        # With the default ``nprocs=1`` a single process is started and
        # ``main`` is called with index 0.
        mp.spawn(
            main,
        )


## Training Job 실행

In [5]:
# VARIABLES
# Namespace the job is created in, as returned by the kakaocloud_kbm helper.
my_namespace = get_default_target_namespace()
# Name under which the PyTorchJob is created, queried and deleted in the cells below.
pytorchjob_name = "parallel-train-pytorch"
# Resource limits handed to the job as `limit_resources`
# (presumably applied to each replica: one MIG GPU slice, 2 CPUs, 4G memory — confirm).
gpu_mig_for_1ea = {
    "nvidia.com/mig-1g.10gb": "1",
    "cpu": "2",
    "memory": "4G"

}
num_workers = 5  # number of GPUs used for parallel training

In [6]:
# Client object used for every Training Job operation in this notebook.
training_client = TrainingClient()

# Create the PyTorchJob from the training function declared above.
training_client.create_pytorchjob_from_func(
    name=pytorchjob_name,
    namespace=my_namespace,
    func=train_pytorch_model,
    base_image="bigdata-150.kr-central-2.kcr.dev/kc-kubeflow/pytorchjob-pytorch:1.12.1-cuda11.3-cudnn8-runtime",
    num_worker_replicas=num_workers-1,  # number of Worker nodes (one node is the Master) = intended parallel-training GPU count - 1
    limit_resources=gpu_mig_for_1ea
)

INFO:root:PyTorchJob kbm-u-admin-jin/parallel-train-pytorch2 has been created


## Training Job 상태 확인

In [7]:
# STATUS DETAILS: print the condition records reported for the job
print(training_client.get_job_conditions(name=pytorchjob_name, job_kind='PyTorchJob'))

# RUN CHECK: print whether the job is currently reported as running
print(f"Is job running: {training_client.is_job_running(name=pytorchjob_name, job_kind='PyTorchJob')}")

[{'last_transition_time': datetime.datetime(2024, 4, 20, 3, 49, 38, tzinfo=tzutc()),
 'last_update_time': datetime.datetime(2024, 4, 20, 3, 49, 38, tzinfo=tzutc()),
 'message': 'PyTorchJob parallel-train-pytorch2 is created.',
 'reason': 'PyTorchJobCreated',
 'status': 'True',
 'type': 'Created'}, {'last_transition_time': datetime.datetime(2024, 4, 20, 3, 49, 40, tzinfo=tzutc()),
 'last_update_time': datetime.datetime(2024, 4, 20, 3, 49, 40, tzinfo=tzutc()),
 'message': 'PyTorchJob parallel-train-pytorch2 is running.',
 'reason': 'JobRunning',
 'status': 'True',
 'type': 'Running'}]
Is job running: True


## 학습 Pod 확인

In [8]:
# List the names of the pods that belong to the job.
training_client.get_job_pod_names(pytorchjob_name)

['parallel-train-pytorch2-master-0',
 'parallel-train-pytorch2-worker-0',
 'parallel-train-pytorch2-worker-1',
 'parallel-train-pytorch2-worker-2',
 'parallel-train-pytorch2-worker-3']

## 로그 출력

In [9]:
# Fetch the job's logs from the container named "pytorch".
training_client.get_job_logs(pytorchjob_name, container="pytorch")

INFO:root:The logs of pod parallel-train-pytorch2-master-0:
 2024-04-20T03:50:12Z INFO     Added key: store_based_barrier_key:1 to store for rank: 0
2024-04-20T03:50:12Z INFO     Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 5 nodes.
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26421880/26421880 [01:10<00:00, 373109.10it/s] 
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 29515/29515 [00:00<00:00, 110908.43it/s]
Extracting data/FashionMNIST/r

## Training Job 삭제

In [10]:
# Delete the PyTorchJob created above.
training_client.delete_pytorchjob(pytorchjob_name)

INFO:root:PyTorchJob kbm-u-admin-jin/parallel-train-pytorch2 has been deleted
