# Kaggle 小技巧：TTA（Test-Time Augmentation，测试时增强）

2020年10月25日 阅读数：674

### 1.原理简述

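TTA（Test-Time Augmentation，测试时增强）：在推理阶段对同一张测试图像施加多种增强（如翻转、旋转），分别送入模型预测，再把每个预测结果（例如检测框）逆变换回原始图像的坐标系并进行融合（如取平均或 WBF），从而提升预测的稳定性和精度。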

### 2.代码实现

```python
"""
augment: 为单张图片制作 TTA
batch_augment: 为一批图片批量制作 TTA
deaugment_boxes: 将在 TTA 图像上得到的预测框还原到原始图像的坐标系中
"""

class BaseTTA:
    """Common interface for one test-time-augmentation (TTA) transform.

    A concrete transform overrides all three hooks:

    * ``augment`` -- transform a single image;
    * ``batch_augment`` -- transform a batch of images;
    * ``deaugment_boxes`` -- map boxes predicted on the transformed image back
      to the coordinates of the original image.

    Every hook on this base class raises ``NotImplementedError``.
    """

    # Side length that subclasses use when mirroring box coordinates;
    # presumably inputs are square images of this size -- TODO confirm.
    image_size = 512

    def augment(self, image):
        """Return ``image`` transformed; subclasses must override."""
        raise NotImplementedError

    def batch_augment(self, images):
        """Return the batch ``images`` transformed; subclasses must override."""
        raise NotImplementedError

    def deaugment_boxes(self, boxes):
        """Return ``boxes`` in original-image coordinates; subclasses must override."""
        raise NotImplementedError

# Horizontal flip.
class TTAHorizontalFlip(BaseTTA):
    """TTA transform that flips along tensor dim 1 (single image) / dim 2 (batch).

    ``deaugment_boxes`` mirrors box columns 1 and 3 about ``image_size``, so
    those columns are assumed to measure the flipped axis.

    NOTE(review): if images are (C, H, W) tensors, dim 1 is the height axis and
    this is visually a *vertical* flip despite the class name; boxes would then
    be (x_min, y_min, x_max, y_max). Confirm against the model's layout.
    """

    def augment(self, image):
        """Return a single image flipped along dim 1."""
        return image.flip(1)

    def batch_augment(self, images):
        """Return a batch flipped along dim 2 (presumably dim 0 is the batch axis)."""
        return images.flip(2)

    def deaugment_boxes(self, boxes):
        """Map boxes predicted on the flipped image back to the original image.

        Columns 1 and 3 are mirrored (``image_size - v``) and swapped so that
        their min/max order is preserved. A new array is returned and ``boxes``
        itself is left untouched, so callers' prediction arrays are not
        clobbered. Assumes ``boxes`` is a NumPy-style array with ``.copy()``.
        """
        res_boxes = boxes.copy()
        # Mirroring reverses order: the old max edge becomes the new min edge.
        res_boxes[:, [1, 3]] = self.image_size - boxes[:, [3, 1]]
        return res_boxes

# Vertical flip.
class TTAVerticalFlip(BaseTTA):
    """TTA transform that flips along tensor dim 2 (single image) / dim 3 (batch).

    ``deaugment_boxes`` mirrors box columns 0 and 2 about ``image_size``, so
    those columns are assumed to measure the flipped axis.

    NOTE(review): if images are (C, H, W) tensors, dim 2 is the width axis and
    this is visually a *horizontal* flip despite the class name; boxes would
    then be (x_min, y_min, x_max, y_max). Confirm against the model's layout.
    """

    def augment(self, image):
        """Return a single image flipped along dim 2."""
        return image.flip(2)

    def batch_augment(self, images):
        """Return a batch flipped along dim 3 (presumably dim 0 is the batch axis)."""
        return images.flip(3)

    def deaugment_boxes(self, boxes):
        """Map boxes predicted on the flipped image back to the original image.

        Columns 0 and 2 are mirrored (``image_size - v``) and swapped so that
        their min/max order is preserved. A new array is returned and ``boxes``
        itself is left untouched, so callers' prediction arrays are not
        clobbered. Assumes ``boxes`` is a NumPy-style array with ``.copy()``.
        """
        res_boxes = boxes.copy()
        # Mirroring reverses order: the old max edge becomes the new min edge.
        res_boxes[:, [0, 2]] = self.image_size - boxes[:, [2, 0]]
        return res_boxes

# Rotate by 90 degrees.
class TTARotate90(BaseTTA):
    """TTA transform that rotates images by 90 degrees (k=1) over two dims.

    Single images rotate over dims (1, 2) and batches over dims (2, 3),
    mirroring the per-image / per-batch dim pattern of the flip transforms.
    Per ``Tensor.rot90``, a positive k rotates from the first listed dim toward
    the second; ``deaugment_boxes`` implements the matching inverse.
    """

    def augment(self, image):
        """Return a single image rotated by 90 degrees over dims (1, 2)."""
        # These bodies were missing (a syntax error). The rotation direction
        # must match the inverse mapping in deaugment_boxes.
        return image.rot90(1, (1, 2))

    def batch_augment(self, images):
        """Return a batch rotated by 90 degrees over dims (2, 3)."""
        return images.rot90(1, (2, 3))

    def deaugment_boxes(self, boxes):
        """Map boxes predicted on the rotated image back to the original image.

        Assuming (C, H, W) images and (x_min, y_min, x_max, y_max) boxes
        (TODO confirm), the rotation sends original pixel (row r, col c) to
        (row image_size - c, col r). The inverse is therefore
        x = image_size - y' and y = x'. Columns are chosen so that each output
        box keeps min <= max. A new array is returned; ``boxes`` is unchanged.
        """
        res_boxes = boxes.copy()
        # x range comes from the mirrored y' range; mirroring swaps min and max.
        res_boxes[:, [0, 2]] = self.image_size - boxes[:, [3, 1]]
        # y range is the x' range unchanged.
        res_boxes[:, [1, 3]] = boxes[:, [0, 2]]
        return res_boxes

# Combine several transforms.
class TTACompose(BaseTTA):
    """Apply a sequence of TTA transforms in order and undo them in reverse."""

    def __init__(self, transforms):
        """
        Args:
            transforms: iterable of objects exposing ``augment``,
                ``batch_augment`` and ``deaugment_boxes``. It is copied into a
                list so that a one-shot iterable (e.g. a generator) is not
                exhausted after the first call and can be traversed in reverse.
        """
        self.transforms = list(transforms)

    def augment(self, image):
        """Apply every transform's ``augment`` to a single image, first to last."""
        for transform in self.transforms:
            image = transform.augment(image)
        return image

    def batch_augment(self, images):
        """Apply every transform's ``batch_augment`` to a batch, first to last."""
        for transform in self.transforms:
            images = transform.batch_augment(images)
        return images

    def prepare_boxes(self, boxes):
        """Return a copy of ``boxes`` with each coordinate pair sorted.

        Afterwards column 0 <= column 2 and column 1 <= column 3. This repairs
        any edges left swapped by the individual deaugment steps.
        """
        result_boxes = boxes.copy()
        # Use ndarray methods so this does not depend on a module-level ``np``.
        xs = boxes[:, [0, 2]]
        ys = boxes[:, [1, 3]]
        result_boxes[:, 0] = xs.min(axis=1)
        result_boxes[:, 2] = xs.max(axis=1)
        result_boxes[:, 1] = ys.min(axis=1)
        result_boxes[:, 3] = ys.max(axis=1)
        return result_boxes

    def deaugment_boxes(self, boxes):
        """Undo every transform on ``boxes`` (last applied first), then sort edges."""
        for transform in reversed(self.transforms):
            boxes = transform.deaugment_boxes(boxes)
        return self.prepare_boxes(boxes)
``````