-
[논문 구현] U-Net 구현하기 1탄 - 학습 코드 짜기Computer Vision 모델 2025. 4. 9. 22:26
논문도 다 읽었으니 UNet 을 직접 pytorch로 구현해보는 실습을 해볼것이다..!
예전에 작성했던 논문 리뷰 링크!
[논문 리뷰] U-Net 꼼꼼한 리뷰
AbstractDeep Neural Network의 성공적인 학습을 위해서는 많은 양의 annotated training sample들이 필요하다.이 paper에서는 annotated sample들을 더욱더 효율적으로 활용하기 위하여 Data augmentation기반의 새로운 ne
whatisworth.tistory.com
추론하는 코드를 짤때는 사용자가 직접 테스트 해볼 수 있도록 streamlit 같은 시각화 도구를 사용하여 데모페이지도 한번 만들어보려고 하는 계획을 가지고 있다 ㅎㅎ
나는 항상 코딩할때 구조부터 짜고 들어간다. 안그러면 코드가 정신없어져서 스파게티 코드가 되어버리기 때문에..ㅎ.ㅎ
1. 코드 구조
. ├── README.md ├── configs │ └── config.yaml ├── input ├── main.py ├── model │ └── unet.py ├── output ├── requirements.txt ├── utils └── weights
- configs : 학습한 weight 경로, 데이터 경로 등 각종 파라미터 저장용
- input : 학습 데이터 저장 경로
- main.py : 추론코드
- model : 모델 구조
- output : 모델이 추론한 결과물들 모아놓는 곳
- requirements.txt : 각종 패키지 설치용
- utils : 각종 시각화,, 기타 잡다한 것들
- weights : 학습한 weight 저장 경로
2. 모델 구조 짜기
(1) Contracting Path
첫번째로 해야할 것은 모델의 구조를 짜는 것이다.
먼저 contracting path, 즉 U-net 구조에서 압축하는 부분을 짤건데 요 부분 같은 경우는 3x3 convolution + relu 가 2번 반복적으로 나타나기 때문에 하나의 block으로 짜주면 보기에 좋다.
또, 아무래도 반복적인 패턴이 있기 때문에, Sequential 을 사용하여 짜주면 구조화 하는데 좋다!
class ConvolutionBlock(nn.Module): def __init__(self, in_channels, out_channels, padding=0, downsample=None): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, padding), nn.BatchNorm2d(out_channels), nn.LeakyReLU(negative_slope=0.1), nn.Conv2d(out_channels, out_channels, padding), nn.BatchNorm2d(out_channels), nn.LeakyReLU(negative_slope=0.1) ) def forward(self, x): out = self.conv(x) return out
그리고 UNet 모델 class에 윗부분을 추가해줄건데, 여기서 중요한 점!! 앞에서 짠 convolution block의 경우, "class"이기 때문에 UNet class에서 init에서 정의하지 않고 들어가버리면 학습이 이상하게 되어버린다
예를 들어,
class UNet(nn.Module): def __init__(self): super().__init__() self.pool = nn.MaxPool2d(2,2) def forward(self, x): out = ConvolutionBlock(3, 64, padding=1)
이런 식으로 forward 함수 밑에 class로 정의한 block을 써버리게 되면 매번 클래스를 새로 인스턴스화 하기 때문에 올바른 학습이 되지 않는다.
따라서,
class UNet(nn.Module): def __init__(self): super().__init__() self.pool = nn.MaxPool2d(2,2) self.down_conv1 = ConvolutionBlock(3, 64, padding=1) def forward(self, x): out = self.down_conv1(x)
이런식으로 init에서 먼저 초기화를 해준 후, 가져다 써야한다.
(2) skip connection
skip connection을 위해, contracting path 4번으로부터 나온 결과물을 각각 저장해준다. 이렇게!
def forward(self, x): x1 = self.down_conv1(x) x2 = self.down_conv2(self.pool(x1)) x3 = self.down_conv3(self.pool(x2)) x4 = self.down_conv4(self.pool(x3))
contracting path에서 skip connection으로 expansion path로부터 나온 결과물이랑 concatenate 시킬 때 사이즈가 안맞다... 이를테면 x4는 [1, 512, 64, 64] 인데, 합쳐야할 대상은 [1, 512, 56, 56]이다. 따라서 center crop 과정을 거쳐야한다.
# skip connection을 위한 center crop def center_crop(enc_feat, target_size): _, _, h, w = enc_feat.shape th, tw = target_size x1 = (w - tw) // 2 y1 = (h - th) // 2 return enc_feat[:, :, y1:y1+th, x1:x1+tw]
(3) expansion path
decoder 부분에 해당하는 expansion path는 특이하게 up convolution, 즉 합성곱 연산을 거꾸로 하는 연산이 필요하다.
pytorch nn에 기본적으로 잘 짜져있는 nn.ConvTranspose2d 를 사용한다.
self.up_transpose1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
이런식으로 총 4개를 필터 사이즈에 맞춰서 작성해주면 된다!
그리고 그림에서 빨간색으로 표시된 부분은 앞서 작성한 convolution block과 구조가 같으므로 재사용해준다.
(4) final layer
final layer는 그림에서도 보다시피 1x1 conv layer를 통과시켜야한다. class 개수에 맞추어 다음과 같이 작성해준다.
self.final = nn.Conv2d(64, num_classes, kernel_size=1)
그럼 완성!!
전체 코드는 다음과 같다!
import torch import os from torch import nn from torch.utils.data import DataLoader from torchvision import datasets, transforms # GPU 장치 선언 device = "cuda" if torch.cuda.is_available() else "cpu" print(device) # skip connection을 위한 center crop def center_crop(enc_feat, target_size): _, _, h, w = enc_feat.shape th, tw = target_size x1 = (w - tw) // 2 y1 = (h - th) // 2 return enc_feat[:, :, y1:y1+th, x1:x1+tw] # convolution block 정의하기 - 3x3 conv + relu 2번 class ConvolutionBlock(nn.Module): def __init__(self, in_channels, out_channels, padding=0, kernel_size=3): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), nn.BatchNorm2d(out_channels), nn.LeakyReLU(negative_slope=0.1), nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), nn.BatchNorm2d(out_channels), nn.LeakyReLU(negative_slope=0.1) ) def forward(self, x): out = self.conv(x) return out class UNet(nn.Module): def __init__(self, num_classes=1): super().__init__() self.pool = nn.MaxPool2d(2,2) self.down_conv1 = ConvolutionBlock(3, 64) self.down_conv2 = ConvolutionBlock(64, 128) self.down_conv3 = ConvolutionBlock(128, 256) self.down_conv4 = ConvolutionBlock(256, 512) self.down_conv5 = ConvolutionBlock(512, 1024) self.up_transpose1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2) self.up_conv1 = ConvolutionBlock(1024, 512) self.up_transpose2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) self.up_conv2 = ConvolutionBlock(512, 256) self.up_transpose3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) self.up_conv3 = ConvolutionBlock(256, 128) self.up_transpose4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) self.up_conv4 = ConvolutionBlock(128, 64) self.final = nn.Conv2d(64, num_classes, kernel_size=1) def forward(self, x): x1 = self.down_conv1(x) x2 = self.down_conv2(self.pool(x1)) x3 = self.down_conv3(self.pool(x2)) x4 = self.down_conv4(self.pool(x3)) out = self.pool(x4) out = self.down_conv5(out) out = self.up_transpose1(out) out = torch.cat((out, center_crop(x4, out.shape[2:])),dim=1) out = self.up_conv1(out) out = self.up_transpose2(out) out = torch.cat((out, center_crop(x3, out.shape[2:])),dim=1) out = self.up_conv2(out) out = self.up_transpose3(out) out = torch.cat((out, center_crop(x2, out.shape[2:])),dim=1) out = self.up_conv3(out) out = self.up_transpose4(out) out = torch.cat((out, center_crop(x1, out.shape[2:])),dim=1) out = self.up_conv4(out) out = self.final(out) return out
'Computer Vision 모델' 카테고리의 다른 글
[논문 리뷰] U-Net 꼼꼼한 리뷰 (0) 2025.03.22