IOU(Intersection over union)

이번 포스팅에서는 많은 Object Detection 논문에서 사용하는 IOU(Intersection over union) 에 대해 알아보자. 참고로 IOU는 jaccard overlap 이라고도 불린다.


위의 그림을 보면 Prediction은  Ground truth와 꽤 겹치는 듯 하면서도 만족스럽지는 못한 것을 알 수있다. 이러한 차이를 어떤 방법으로 Prediction과 Ground trutt를 비교하면서 평가를 할 수 있을까? 

Image 1 Image 2


바로 왼쪽 그림의 두 박스의 교집합과 오른쪽 그림의 합집합을 이용을 하면 prediction이 얼마나 ground truth와 겹쳐 있는지 수치적으로 비교할 수 있다.


즉 IOU는 이름에서부터 알 수 있듯이 두 박스(prediction and ground truth)의 교집합을 합집합으로 나눈것이고 이 값은  0~1사이의 값을 가지게 된다. 보통 Object Detection에서 IOU 0.5를 기준으로 그 이상이면 해당 prediction은 어떤 물체를 가리키고 있으며 좀 더 높은 iou를 내도록 계속해서 학습시키려 하고 0.5 이하라면 배경화면 즉 물체가 없다고 간주하고 제거를 하는 방향으로 사용한다. (논문마다 threshold 등등 조금씩 다름)


위와 같은 그림에서 빨간색의 ground truth와 파란색의 prediction 과의  intersection을 어떻게 구할 수 있을까?

(이미지 좌표값은 좌상단이 0,0이고 우측으로갈수록 x값이 커지게 되고 아래로 갈수록 y값이 커지게 되는 것을 유의하자) 


먼저 intersection의 좌상단 좌표를 알아내야 한다. 좌상단의 좌표값은 두 박스중에 x1값과 y1값이 큰 값을 각각 가져오면 된다. 위 그림에서는 x1,y1 둘다 파란색 박스보다 빨간색 박스가 크기때문에 ground truthd의 x1,y1값이 intersection의 좌 상단 좌표값이 된다.


intersection의 우 하단은 좌상단의 반대로 진행하면 된다. 두 박스의 x2와 y2를 비교하여 더 작은 값을 각각 가져오면 되고 위 그림에서는 파란색 박스 즉 prediction의 x2, y2의 좌표값이 더 작으므로 이것들이 intersection의 우 하단 좌표값이 된다.

과연 위와같은 방식이 다른 박스에서도 일반화가 되는지 좀 더 복잡한 그림으로 확인해보자.

Image 1 Image 2


왼쪽 그림의 intersection을 구해보자. 먼저 두 박스의 좌 상단 좌표값을 비교해보면 x1은 빨간색 박스가 더 크고 y1은 파란색 박스가 더 큰것을 볼수 있다. 그리고 우 하단의 좌표값을 비교해보면 x좌표는 파란색 박스의 x2가 더 작고 y 좌표는 빨간색 좌표의 y2가 더 작으므로 intersection의 좌표는 오른쪽 그림과 같을 것이다.

(intersection과는 다르게 Union은 간단하게 구할 수 있기 때문에 코드에서 설명할 예정이다.)


구현코드

import torch

def intersection_over_union(boxes_preds, boxes_labels, box_format='midpoint'):
    '''
    Calculates intersection over union
    
    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes  (BATCH_SIZE, 4)
        boxes_labels (tensor): Correct labels of Bounding Boxes  (BATCH_SIZE, 4)
        box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
    
    Returns:
        tensor: Intersection over union for all examples
    ''' 
    
    if box_format == 'midpoint':
        box1_x1 = boxes_preds[:,0:1] - boxes_preds[:,2:3] / 2
        box1_y1 = boxes_preds[:,1:2] - boxes_preds[:,3:4] / 2
        box1_x2 = boxes_preds[:,0:1] + boxes_preds[:,2:3] / 2
        box1_y2 = boxes_preds[:,1:2] + boxes_preds[:,3:4] / 2

        box2_x1 = boxes_labels[:,0:1] - boxes_labels[:,2:3] / 2
        box2_y1 = boxes_labels[:,1:2] - boxes_labels[:,3:4] / 2
        box2_x2 = boxes_labels[:,0:1] + boxes_labels[:,2:3] / 2
        box2_y2 = boxes_labels[:,1:2] + boxes_labels[:,3:4] / 2
    
    elif box_format == 'corners':   
        box1_x1 = boxes_preds[:,0:1]
        box1_y1 = boxes_preds[:,1:2]
        box1_x2 = boxes_preds[:,2:3]
        box1_y2 = boxes_preds[:,3:4]

        box2_x1 = boxes_labels[:,0:1]
        box2_y1 = boxes_labels[:,1:2]
        box2_x2 = boxes_labels[:,2:3]
        box2_y2 = boxes_labels[:,3:4]
        
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)
    
    # .clamp(0) is for the case when they do not intersect
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    
    box1_area = abs((box1_x2 - box1_x1) * (box1_y1 - box1_y2))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y1 - box2_y2))
    
    return intersection / (box1_area + box2_area - intersection + 1e-6)

bounding box의 좌표값의 형태는 크게 midpoint타입과 coner타입 두가지로 나뉜다. midpoint는 (cx, cy, w, h)의 좌표값을 가지고 corner타입은 (x1, y1, x2, y2)의 좌표값을 가진다.

위 함수에서는 boxes_preds[: , 0:1]은 모든 box에 대해 x1값만 추출하는 부분인데 굳이 boxes_preds[: , 0]을 안쓰고 boxes_preds[: , 0:1]을 쓴 이유는 차원을 (batchsize, 1) 로 유지하기 위함이다. boxes_preds[: , 0] 만 쓰게되면 (batchsize) 와 같이 1차워 벡터가 된다.

intersection 변수를 선언할때 torch.clamp(0)을 사용한 이유는 두 박스가 전혀 겹치지 않을때 intersection을 0으로 만들기 위해 쓴다. clamp(0)을 사용하면 0이하의값 즉 negative값은 모조리 0으로 만들어 준다. (두 박스가 전혀 겹치지 않았을때 x2-x1 혹은 y2 -y1을 할때 0이나오거나 negative가 나오는데 이 negative값을 0으로 만들어 주기 위한 목적)

return에서 (box1_area + box2_area - intersection)이 바로 Union이 되고 뒤에 1e-6은 stability(분모가 0이되는 것을 방지)를 위한 작은 상수값이다.


End

이번 포스팅에서는 Object Detection에서 사용되는 IOU의 개념에 대해 알아보았다. IOU 계산에서 중요한 것은 입력되는 박스의 형태가 midpoint인지 corner 형태 인지에 대해 확실히 알고 그에 맞는 코딩을 해야 한다는 것이다.

reference : www.youtube.com/watch?v=XXYG5ZWtjj0

업데이트:

댓글남기기