tensorflow1.15-keras 多标签 xception训练与测试

本任务是对人脸属性做多标签分类:性别(female、male)与年龄段(children、young、adult、older)。

Xception 可以直接使用官方提供的实现;这里的网络是参考别人的代码自己搭建的。

这里主要可以学习的是如何自己编写数据生成器 data_generator,训练时通过 generator=train_gen.get_mini_batch(transform = True) 把它传给 fit_generator。

数据增强脚本

transform_gao.py

import numpy as np
import random
#import torchvision.transforms as transforms
from PIL import Image, ImageFilter
import skimage


class RandomBlur(object):
    """Randomly blur a PIL image.

    When the random gate passes, one of three PIL filters is picked at
    random: a Gaussian blur or a box blur (both with a radius drawn
    uniformly from ``[0, radius]``) or a 3x3 median filter. Otherwise the
    image is returned unchanged.
    """

    def __init__(self, prob=0.5, radius=2):
        self.radius = radius
        self.prob = prob

    def __call__(self, img):
        skip = random.random() > self.prob
        if skip:
            return img

        blur_radius = random.uniform(0, self.radius)
        candidates = (
            ImageFilter.GaussianBlur(blur_radius),
            ImageFilter.BoxBlur(blur_radius),
            ImageFilter.MedianFilter(size=3),
        )
        return img.filter(random.choice(candidates))


class RandomNoise(object):
    """Randomly add gaussian, speckle or salt-and-pepper noise to a PIL image.

    ``noise`` is used as the variance for the gaussian/speckle modes and as
    the ``amount`` for the salt-and-pepper mode. When the random gate does
    not pass, the image is returned unchanged.
    """

    def __init__(self, prob=0.5, noise=0.01):
        self.noise = noise
        self.prob = prob

    def __call__(self, img):
        if random.random() > self.prob:
            return img

        pixels = np.array(img)
        # (mode, keyword arguments) pairs for skimage.util.random_noise.
        variants = (
            ('gaussian', {'mean': 0, 'var': self.noise}),
            ('speckle', {'mean': 0, 'var': self.noise}),
            ('s&p', {'amount': self.noise}),
        )
        mode, kwargs = random.choice(variants)
        noisy = skimage.util.random_noise(pixels, mode, **kwargs)
        # The noisy result is scaled by 255 and cast to uint8 before it is
        # wrapped in a PIL image again.
        return Image.fromarray((noisy * 255).astype(np.uint8))


class ResizeWithPad(object):
    """Resize a PIL image and paste it onto a gray canvas of a fixed size.

    ``size`` is the ``(width, height)`` of the output canvas. The image
    height is reduced by a random amount in ``[0, delta_h]`` and the image
    is placed at a random offset inside the remaining slack.
    """

    def __init__(self, size, delta_h=5, interpolation=Image.BILINEAR):
        self.interpolation = interpolation
        self.delta_h = delta_h
        self.size = size

    def __call__(self, img):
        src_w, src_h = img.size
        dst_w, dst_h = self.size

        # Width that keeps the source aspect ratio at the full target height.
        new_w = int(round(src_w * dst_h / float(src_h)))
        offset_x = random.randint(0, max(0, dst_w - new_w))

        # Shrink the height by 0..delta_h pixels and shift down within that slack.
        shrink = random.randint(0, self.delta_h)
        new_h = dst_h - shrink
        offset_y = random.randint(0, shrink)

        resized = img.resize((new_w, new_h), self.interpolation)
        canvas = Image.new('RGB', self.size, (127, 127, 127))
        canvas.paste(resized, (offset_x, offset_y))
        return canvas


class Normalize(object):
    """Convert an image to a channel-first float32 torch tensor in ``[-1, 1]``.

    The original implementation called ``transforms.functional.to_tensor``,
    but the ``torchvision.transforms`` import of this module is commented
    out, so every call failed with ``NameError``. The conversion is done
    here with numpy and torch directly instead.

    Accepts a PIL image or a numpy array laid out as ``(H, W)`` or
    ``(H, W, C)``. ``uint8`` data is scaled by ``1 / 255`` first; other
    dtypes are only cast to float32 (NOTE(review): assumed to be already in
    ``[0, 1]`` -- confirm against callers). The result is then mapped with
    ``(x - 0.5) / 0.5``.
    """

    def __init__(self):
        pass

    def __call__(self, img):
        # Local import: only this transform needs torch, so the rest of the
        # module stays importable without it.
        import torch

        pixels = np.asarray(img)
        scale = 255.0 if pixels.dtype == np.uint8 else 1.0
        if pixels.ndim == 2:
            # Single-channel input: add the channel axis.
            pixels = pixels[:, :, None]

        # HWC -> CHW. np.array() always copies, so the in-place ops below
        # never touch the caller's data.
        chw = np.array(pixels.transpose(2, 0, 1), dtype=np.float32, order='C')
        tensor = torch.from_numpy(chw)
        tensor.div_(scale).sub_(0.5).div_(0.5)
        return tensor


class RandomAspect(object):
    """Randomly change the aspect ratio of a PIL image.

    A factor ``a`` is drawn uniformly from ``aspect``; the width is
    multiplied by ``a`` and the height divided by ``a`` (both truncated to
    integers) before resizing.
    """

    def __init__(self, aspect=(3./4., 4./3.), interpolation=Image.BILINEAR):
        self.interpolation = interpolation
        self.aspect = aspect

    def __call__(self, img):
        width, height = img.size
        factor = random.uniform(self.aspect[0], self.aspect[1])
        target_size = (int(width * factor), int(height / factor))
        return img.resize(target_size, self.interpolation)


def randomColor(image):
    """
    Apply random colour jitter to an image.

    When the random gate does not pass, the image is returned untouched.
    Otherwise saturation, brightness and contrast are all enhanced with one
    shared factor drawn from [0.85, 1.25], and sharpness with a second,
    independently drawn factor from the same range.

    :param image: PIL image
    :return: PIL image with jittered colour
    """
    from PIL import Image, ImageEnhance, ImageOps, ImageFile
    if random.random() > 0.5:
        return image

    low, high = 0.85, 1.25
    shared_factor = np.random.uniform(low, high)

    jittered = image
    # Saturation, then brightness, then contrast -- all with the same factor.
    for enhancer in (ImageEnhance.Color, ImageEnhance.Brightness, ImageEnhance.Contrast):
        jittered = enhancer(jittered).enhance(shared_factor)

    sharpness_factor = np.random.uniform(low, high)
    return ImageEnhance.Sharpness(jittered).enhance(sharpness_factor)


def elastic_transform(image, alpha=0.8, sigma=1.25, alpha_affine=0.08, random_state=None):
    """Elastic deformation of images as described in [Simard2003]_ (with modifications).

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
         Convolutional Neural Networks applied to Visual Document Analysis", in
         Proc. of the International Conference on Document Analysis and
         Recognition, 2003.
     Based on https://gist.github.com/erniejunior/601cdf56d2b424757de5

    The deformation is skipped (input returned unchanged) whenever
    ``random.random() < 0.7``.

    :param image: PIL image; it is converted to an array that must be 3-D
        (``shape[2]`` is read below).
    :param alpha: scale applied to the smoothed displacement fields.
    :param sigma: standard deviation of the Gaussian used to smooth them.
    :param alpha_affine: half-width of the uniform jitter added to the three
        anchor points of the random affine warp.
    :param random_state: optional ``numpy.random.RandomState``; a fresh one
        is created when omitted.
    :return: the input object when skipped, otherwise a new PIL image.
    """

    if random.random() < 0.7:
        return image

    # Import from the public scipy.ndimage namespace: the
    # ``scipy.ndimage.interpolation`` / ``scipy.ndimage.filters`` submodules
    # used before are deprecated.
    from scipy.ndimage import gaussian_filter, map_coordinates
    import cv2

    if random_state is None:
        random_state = np.random.RandomState(None)

    image = np.array(image)  # PIL -> numpy
    shape = image.shape
    shape_size = shape[:2]

    # Random affine: jitter three anchor points around the image centre.
    center_square = np.float32(shape_size) // 2
    square_size = min(shape_size) // 3
    pts1 = np.float32([center_square + square_size,
                       [center_square[0] + square_size, center_square[1] - square_size],
                       center_square - square_size])
    pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32)
    M = cv2.getAffineTransform(pts1, pts2)
    image = cv2.warpAffine(image, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)

    # Smoothed random displacement fields along x and y.
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha

    x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
    indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))

    warped = map_coordinates(image, indices, order=1, mode='reflect').reshape(shape)
    return Image.fromarray(warped)  # numpy -> PIL

产生批数据脚本

my_xception_multi_label_data_generator.py

import os
from PIL import Image
import numpy as np
import random
from transform_gao import *

# First part of a class-folder name (before the underscore) -> gender index.
dict_gender = {'f': 0, 'm': 1}

# Second part of a class-folder name (after the underscore) -> age-group index.
dict_age = {'children': 0, 'young': 1, 'adult': 2, 'older': 3}

def make_one_hot(class_num,label):
    """Return a float32 vector of length ``class_num`` with 1.0 at index ``label``."""
    encoded = np.zeros(class_num, dtype=np.float32)
    encoded[label] = 1.0
    return encoded



class data_generator:
    """Endless mini-batch generator for the gender/age multi-label task.

    Images are discovered under ``path_dir``; every folder that contains
    images must be named ``<gender>_<age>`` with keys taken from
    ``dict_gender`` and ``dict_age`` (e.g. ``f_young``).

    :param path_dir: root directory that is walked recursively.
    :param batch_size: number of samples per yielded batch.
    :param resize_: side length (pixels) every image is resized to.
    :param class_gender: length of the gender one-hot vector.
    :param class_age: length of the age one-hot vector.
    """

    # File-name suffixes (case-sensitive) that are treated as images.
    _IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")

    def __init__(self,path_dir,batch_size,resize_=160,class_gender = 2,class_age = 4):
        self.index=0
        self.path_label = []
        self.batch_size=batch_size
        self.class_gender = class_gender
        self.class_age = class_age
        self.resize_ = resize_

        self.path = path_dir
        self.load_path()

    def load_path(self):
        """Rebuild ``self.path_label`` from disk and shuffle it once.

        Folders without any image file are skipped, so stray files such as
        a ``result.txt`` no longer trigger a label lookup.

        :raises KeyError: if a folder containing images is not named
            ``<gender>_<age>`` with known keys.
        """
        self.path_label = []
        print('*' * 10)
        print('------start---load_path')
        for root, dirs, files in os.walk(self.path):
            image_names = [name for name in files
                           if name.endswith(self._IMAGE_EXTENSIONS)]
            if not image_names:
                continue

            # basename(normpath(...)) copes with trailing separators and
            # with platforms whose separator is not '/'.
            dir_name = os.path.basename(os.path.normpath(root)).split('_')
            try:
                label_gender = dict_gender[dir_name[0]]
                label_age = dict_age[dir_name[1]]
            except (KeyError, IndexError):
                raise KeyError(
                    "image folder %r is not named <gender>_<age>" % (root,))

            for name in image_names:
                self.path_label.append({"img_path": os.path.join(root, name),
                                        "label_gender": label_gender,
                                        "label_age": label_age})

        # Shuffle once after everything is collected (this used to run
        # after every directory).
        random.shuffle(self.path_label)

        print('------end---load_path')
        print('*' * 10)

    def get_mini_batch(self,transform = False):
        """Yield ``(images, [gender_one_hot, age_one_hot])`` batches forever.

        ``images`` is a float32 array of pixel values divided by 255, one
        ``resize_`` x ``resize_`` RGB image per sample. The sample list is
        reshuffled each time it has been consumed completely.

        :param transform: when truthy, apply the random augmentations
            (colour jitter, noise, blur, aspect change, elastic warp).
        :raises ValueError: on first iteration if no images were found.
        """
        if not self.path_label:
            raise ValueError("no images found under %r" % (self.path,))

        while True:
            batch_images = []
            batch_labels_gender = []
            batch_labels_age = []
            for _ in range(self.batch_size):
                if self.index == len(self.path_label):
                    # One full pass done: start over in a new random order.
                    self.index = 0
                    random.shuffle(self.path_label)

                sample = self.path_label[self.index]

                # The context manager closes the file handle; convert()
                # returns an independent, fully loaded RGB copy.
                with Image.open(sample["img_path"]) as opened:
                    img = opened.convert('RGB')

                if transform:
                    img = randomColor(img)
                    img = RandomNoise(prob=0.5, noise=0.01)(img)
                    img = RandomBlur(prob=0.5, radius=1.5)(img)
                    img = RandomAspect(aspect=(0.93, 1.12))(img)
                    img = elastic_transform(img,alpha_affine=0.01)

                img = img.resize((self.resize_, self.resize_))

                # Scale to [0, 1]; same arithmetic as before (float64 math,
                # then cast to float32).
                pixels = (np.array(img) * 1.0 / 255.0).astype(np.float32)
                batch_images.append(pixels)

                batch_labels_gender.append(
                    make_one_hot(self.class_gender, sample["label_gender"]))
                batch_labels_age.append(
                    make_one_hot(self.class_age, sample["label_age"]))
                self.index += 1

            yield (np.array(batch_images),
                   [np.array(batch_labels_gender), np.array(batch_labels_age)])




if __name__ == '__main__':
    # Smoke test: build a generator over a local dataset and pull one batch.
    path_dir = "/data_2/big-data/compete/20200323/src_data/age_gender1/test/"
    batch_size = 10
    resize_ = 160
    data_gen = data_generator(path_dir,batch_size,resize_)
    # get_mini_batch() only creates the generator object; nothing is loaded
    # until it is advanced, so call next() to actually produce a batch.
    batch_images, (batch_labels_gender, batch_labels_age) = next(data_gen.get_mini_batch())
    print('batch_images.shape=', batch_images.shape)
    print('batch_labels_gender.shape=', batch_labels_gender.shape)
    print('batch_labels_age.shape=', batch_labels_age.shape)

训练脚本


import os
import sys
import tensorflow as tf
import time
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda,Dropout

from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from tensorflow.keras.callbacks import ModelCheckpoint
from my_xception_multi_label_data_generator import *
# Print the TensorFlow version and the Python version info at start-up.
print(tf.__version__)
print(sys.version_info)


# Side length of the square network input, samples per batch, epoch count.
SIZE, BATCH_SIZE, EPOCH = 160, 5, 100000

# Lengths of the gender and age one-hot label vectors.
class_gender, class_age = 2, 4

# NOTE(review): training and validation currently read the same directory.
path_dir_train = "/data_2/big-data/compete/20200323/src_data/age_gender1/test/"
path_dir_test = "/data_2/big-data/compete/20200323/src_data/age_gender1/test/"

def cbam(inputs):
    """Re-weight the channels of ``inputs`` with a learned per-channel gate.

    The feature map is globally average-pooled, passed through two dense
    layers (a quarter of the channels with ReLU, then all channels with
    softmax), reshaped to ``(1, 1, channels)`` and multiplied back onto
    ``inputs``.
    """
    channels = int(inputs.shape[-1])
    keras_layers = tf.keras.layers

    pooled = keras_layers.GlobalAveragePooling2D()(inputs)
    hidden = keras_layers.Dense(channels // 4)(pooled)
    hidden = keras_layers.Activation("relu")(hidden)
    gate = keras_layers.Dense(channels)(hidden)
    gate = keras_layers.Activation("softmax")(gate)
    gate = keras_layers.Reshape((1, 1, channels))(gate)
    return keras_layers.multiply([inputs, gate])


def Xception(input_shape=(SIZE, SIZE, 3), classes=8):
    """Build an Xception-style network with two softmax heads (gender, age).

    Extra ``Dropout`` layers sit on the main branch before the residual
    additions of blocks 2, 3, 4 and 13.

    # Arguments
        input_shape: shape tuple of the input image, ``(SIZE, SIZE, 3)``
            by default.
        classes: not used by this function; the heads are hard-coded to
            2 units (``out_gender``) and 4 units (``out_age``).
    # Returns
        A Keras model named ``xception-multi-label`` whose outputs are
        ``[out_gender, out_age]``.
    """

    img_input = layers.Input(shape=input_shape)

    channel_axis = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1

    # Block 1: two plain convolutions.
    x = layers.Conv2D(32, (3, 3),
                      strides=(2, 2),
                      use_bias=False,
                      name='block1_conv1')(img_input)
    x = layers.BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
    x = layers.Activation('relu', name='block1_conv1_act')(x)
    x = layers.Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
    x = layers.Activation('relu', name='block1_conv2_act')(x)

    # Block 2: strided 1x1 shortcut + two separable convolutions (128 filters).
    residual = layers.Conv2D(128, (1, 1),
                             strides=(2, 2),
                             padding='same',
                             use_bias=False)(x)
    residual = layers.BatchNormalization(axis=channel_axis)(residual)

    x = layers.SeparableConv2D(128, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block2_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block2_sepconv2_act')(x)
    x = layers.SeparableConv2D(128, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block2_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x)

    x = layers.MaxPooling2D((3, 3),
                            strides=(2, 2),
                            padding='same',
                            name='block2_pool')(x)
    x = Dropout(0.25)(x)  # dropout on the main branch before the residual add
    x = layers.add([x, residual])

    # Block 3: same pattern with 256 filters.
    residual = layers.Conv2D(256, (1, 1), strides=(2, 2),
                             padding='same', use_bias=False)(x)
    residual = layers.BatchNormalization(axis=channel_axis)(residual)

    x = layers.Activation('relu', name='block3_sepconv1_act')(x)
    x = layers.SeparableConv2D(256, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block3_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block3_sepconv2_act')(x)
    x = layers.SeparableConv2D(256, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block3_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x)

    x = layers.MaxPooling2D((3, 3), strides=(2, 2),
                            padding='same',
                            name='block3_pool')(x)
    x = Dropout(0.25)(x)
    x = layers.add([x, residual])

    # Block 4: same pattern with 728 filters.
    residual = layers.Conv2D(728, (1, 1),
                             strides=(2, 2),
                             padding='same',
                             use_bias=False)(x)
    residual = layers.BatchNormalization(axis=channel_axis)(residual)

    x = layers.Activation('relu', name='block4_sepconv1_act')(x)
    x = layers.SeparableConv2D(728, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block4_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block4_sepconv2_act')(x)
    x = layers.SeparableConv2D(728, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block4_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x)

    x = layers.MaxPooling2D((3, 3), strides=(2, 2),
                            padding='same',
                            name='block4_pool')(x)
    x = Dropout(0.25)(x)
    x = layers.add([x, residual])

    # Blocks 5-12: eight identical blocks of three separable convolutions
    # (728 filters) with an identity shortcut.
    for i in range(8):
        residual = x
        prefix = 'block' + str(i + 5)

        x = layers.Activation('relu', name=prefix + '_sepconv1_act')(x)
        x = layers.SeparableConv2D(728, (3, 3),
                                   padding='same',
                                   use_bias=False,
                                   name=prefix + '_sepconv1')(x)
        x = layers.BatchNormalization(axis=channel_axis,
                                      name=prefix + '_sepconv1_bn')(x)
        x = layers.Activation('relu', name=prefix + '_sepconv2_act')(x)
        x = layers.SeparableConv2D(728, (3, 3),
                                   padding='same',
                                   use_bias=False,
                                   name=prefix + '_sepconv2')(x)
        x = layers.BatchNormalization(axis=channel_axis,
                                      name=prefix + '_sepconv2_bn')(x)
        x = layers.Activation('relu', name=prefix + '_sepconv3_act')(x)
        x = layers.SeparableConv2D(728, (3, 3),
                                   padding='same',
                                   use_bias=False,
                                   name=prefix + '_sepconv3')(x)
        x = layers.BatchNormalization(axis=channel_axis,
                                      name=prefix + '_sepconv3_bn')(x)

        x = layers.add([x, residual])

    # Block 13: strided 1x1 shortcut (1024 filters) + separable convs 728 -> 1024.
    residual = layers.Conv2D(1024, (1, 1), strides=(2, 2),
                             padding='same', use_bias=False)(x)
    residual = layers.BatchNormalization(axis=channel_axis)(residual)

    x = layers.Activation('relu', name='block13_sepconv1_act')(x)
    x = layers.SeparableConv2D(728, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block13_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block13_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block13_sepconv2_act')(x)
    x = layers.SeparableConv2D(1024, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block13_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block13_sepconv2_bn')(x)

    x = layers.MaxPooling2D((3, 3),
                            strides=(2, 2),
                            padding='same',
                            name='block13_pool')(x)
    x = Dropout(0.1)(x)
    x = layers.add([x, residual])

    # Block 14: two separable convolutions (1536, 2048 filters) without shortcut.
    x = layers.SeparableConv2D(1536, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block14_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block14_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block14_sepconv1_act')(x)

    x = layers.SeparableConv2D(2048, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block14_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block14_sepconv2_bn')(x)
    x = layers.Activation('relu', name='block14_sepconv2_act')(x)
    # x=cbam(x)
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    # x = layers.Dense(classes, activation='softmax', name='predictions')(x)


    # Age head: 512 -> 256 -> dropout -> 4-way softmax.
    x1 = tf.keras.layers.Dense(512, activation='relu')(x)
    x1 = tf.keras.layers.Dense(256, activation='relu')(x1)
    out_age = tf.keras.layers.Dense(4, activation='softmax', name='out_age')(Dropout(0.5)(x1))

    # Gender head: 512 -> dropout -> 2-way softmax.
    x2 = tf.keras.layers.Dense(512, activation='relu')(x)
    out_gender = tf.keras.layers.Dense(2, activation='softmax', name='out_gender')(Dropout(0.5)(x2))

    # Output order is gender first, then age.
    predictions = [out_gender,out_age]

    model = tf.keras.models.Model(inputs=img_input, outputs=predictions,name='xception-multi-label')
    return model

# Build the two-head model and print its layer table.
model = Xception()
model.summary()


# Both heads use categorical cross-entropy; the age loss is weighted 1.0
# and the gender loss 0.25.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss={'out_age':'categorical_crossentropy',
                    'out_gender':'categorical_crossentropy'},
              loss_weights={
                  'out_age': 1.0,
                  'out_gender': 0.25
              },
              metrics=['accuracy'])



train_gen=data_generator(path_dir_train,BATCH_SIZE,SIZE,class_gender,class_age)
test_gen=data_generator(path_dir_test,BATCH_SIZE,SIZE,class_gender,class_age)

num_train = len(train_gen.path_label)
num_test = len(test_gen.path_label)
print('*'*100)
print('num_train=',num_train)
print('num_test=',num_test)
print('*'*100)


# Debug aid (kept disabled): show one augmented batch with its labels.
# The generator yields (images, [gender, age]), so unpack it accordingly.
#
# import cv2
# x, (y_gender, y_age) = next(train_gen.get_mini_batch(transform = True))
# print(x.shape, y_gender.shape, y_age.shape)
# for j in range(BATCH_SIZE):
#     print(y_gender[j])
#     print(y_age[j])
#     cv2.imshow("1", x[j][:,:,::-1])
#     cv2.waitKey(0)


# Epoch number and per-head validation accuracies keep file names unique.
filepath = "./model/multi-label-model_{epoch:03d}-{val_out_gender_acc:.4f}-{val_out_age_acc:.4f}.h5"
# Make sure the target directory exists before the first checkpoint is written.
os.makedirs(os.path.dirname(filepath), exist_ok=True)
# Monitor a per-output metric key, matching the keys used in the file name
# template above; the former 'val_acc' does not match any of them.
# save_best_only=False, so a checkpoint is written regardless.
checkpoint = ModelCheckpoint(filepath=filepath, monitor='val_out_age_acc', verbose=1,
                             save_best_only=False, mode='max')


history = model.fit_generator(
      generator=train_gen.get_mini_batch(transform = True),
      # At least one step even when the dataset is smaller than one batch.
      steps_per_epoch=max(1, num_train // BATCH_SIZE),
      epochs=EPOCH,
      validation_data=test_gen.get_mini_batch(),
      validation_steps=max(1, num_test // BATCH_SIZE),
      verbose=1,
      callbacks=[checkpoint])

测试脚本


import os
import sys
import tensorflow as tf
import time
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from keras.preprocessing import image
import numpy as np
import cv2
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda,Dropout

from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from tensorflow.keras.callbacks import ModelCheckpoint
from my_xception_multi_label_data_generator import *






# Print the TensorFlow version and the Python version info at start-up.
print(tf.__version__)
print(sys.version_info)




# Side length of the square network input, samples per batch, epoch count.
SIZE, BATCH_SIZE, EPOCH = 160, 5, 100000

# Lengths of the gender and age one-hot label vectors.
class_gender, class_age = 2, 4

# NOTE(review): both paths point at the same directory.
path_dir_train = "/data_2/big-data/compete/20200323/age_gender-aug-my/test/"
path_dir_test = "/data_2/big-data/compete/20200323/age_gender-aug-my/test/"






def cbam(inputs):
    """Re-weight the channels of ``inputs`` with a learned per-channel gate.

    The feature map is globally average-pooled, passed through two dense
    layers (a quarter of the channels with ReLU, then all channels with
    softmax), reshaped to ``(1, 1, channels)`` and multiplied back onto
    ``inputs``.
    """
    channels = int(inputs.shape[-1])
    keras_layers = tf.keras.layers

    pooled = keras_layers.GlobalAveragePooling2D()(inputs)
    hidden = keras_layers.Dense(channels // 4)(pooled)
    hidden = keras_layers.Activation("relu")(hidden)
    gate = keras_layers.Dense(channels)(hidden)
    gate = keras_layers.Activation("softmax")(gate)
    gate = keras_layers.Reshape((1, 1, channels))(gate)
    return keras_layers.multiply([inputs, gate])


def Xception(input_shape=(SIZE, SIZE, 3), classes=8):
    """Build an Xception-style network with two softmax heads (gender, age).

    Extra ``Dropout`` layers sit on the main branch before the residual
    additions of blocks 2, 3, 4 and 13.

    NOTE(review): trained weights are loaded into this model from an
    ``.h5`` file further down in this script, so the layer definitions
    here presumably have to stay in step with the model that produced
    that file -- confirm before changing anything.

    # Arguments
        input_shape: shape tuple of the input image, ``(SIZE, SIZE, 3)``
            by default.
        classes: not used by this function; the heads are hard-coded to
            2 units (``out_gender``) and 4 units (``out_age``).
    # Returns
        A Keras model named ``xception-multi-label`` whose outputs are
        ``[out_gender, out_age]``.
    """

    img_input = layers.Input(shape=input_shape)

    channel_axis = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1

    # Block 1: two plain convolutions.
    x = layers.Conv2D(32, (3, 3),
                      strides=(2, 2),
                      use_bias=False,
                      name='block1_conv1')(img_input)
    x = layers.BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
    x = layers.Activation('relu', name='block1_conv1_act')(x)
    x = layers.Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
    x = layers.Activation('relu', name='block1_conv2_act')(x)

    # Block 2: strided 1x1 shortcut + two separable convolutions (128 filters).
    residual = layers.Conv2D(128, (1, 1),
                             strides=(2, 2),
                             padding='same',
                             use_bias=False)(x)
    residual = layers.BatchNormalization(axis=channel_axis)(residual)

    x = layers.SeparableConv2D(128, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block2_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block2_sepconv2_act')(x)
    x = layers.SeparableConv2D(128, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block2_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x)

    x = layers.MaxPooling2D((3, 3),
                            strides=(2, 2),
                            padding='same',
                            name='block2_pool')(x)
    x = Dropout(0.25)(x)  # dropout on the main branch before the residual add
    x = layers.add([x, residual])

    # Block 3: same pattern with 256 filters.
    residual = layers.Conv2D(256, (1, 1), strides=(2, 2),
                             padding='same', use_bias=False)(x)
    residual = layers.BatchNormalization(axis=channel_axis)(residual)

    x = layers.Activation('relu', name='block3_sepconv1_act')(x)
    x = layers.SeparableConv2D(256, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block3_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block3_sepconv2_act')(x)
    x = layers.SeparableConv2D(256, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block3_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x)

    x = layers.MaxPooling2D((3, 3), strides=(2, 2),
                            padding='same',
                            name='block3_pool')(x)
    x = Dropout(0.25)(x)
    x = layers.add([x, residual])

    # Block 4: same pattern with 728 filters.
    residual = layers.Conv2D(728, (1, 1),
                             strides=(2, 2),
                             padding='same',
                             use_bias=False)(x)
    residual = layers.BatchNormalization(axis=channel_axis)(residual)

    x = layers.Activation('relu', name='block4_sepconv1_act')(x)
    x = layers.SeparableConv2D(728, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block4_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block4_sepconv2_act')(x)
    x = layers.SeparableConv2D(728, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block4_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x)

    x = layers.MaxPooling2D((3, 3), strides=(2, 2),
                            padding='same',
                            name='block4_pool')(x)
    x = Dropout(0.25)(x)
    x = layers.add([x, residual])

    # Blocks 5-12: eight identical blocks of three separable convolutions
    # (728 filters) with an identity shortcut.
    for i in range(8):
        residual = x
        prefix = 'block' + str(i + 5)

        x = layers.Activation('relu', name=prefix + '_sepconv1_act')(x)
        x = layers.SeparableConv2D(728, (3, 3),
                                   padding='same',
                                   use_bias=False,
                                   name=prefix + '_sepconv1')(x)
        x = layers.BatchNormalization(axis=channel_axis,
                                      name=prefix + '_sepconv1_bn')(x)
        x = layers.Activation('relu', name=prefix + '_sepconv2_act')(x)
        x = layers.SeparableConv2D(728, (3, 3),
                                   padding='same',
                                   use_bias=False,
                                   name=prefix + '_sepconv2')(x)
        x = layers.BatchNormalization(axis=channel_axis,
                                      name=prefix + '_sepconv2_bn')(x)
        x = layers.Activation('relu', name=prefix + '_sepconv3_act')(x)
        x = layers.SeparableConv2D(728, (3, 3),
                                   padding='same',
                                   use_bias=False,
                                   name=prefix + '_sepconv3')(x)
        x = layers.BatchNormalization(axis=channel_axis,
                                      name=prefix + '_sepconv3_bn')(x)

        x = layers.add([x, residual])

    # Block 13: strided 1x1 shortcut (1024 filters) + separable convs 728 -> 1024.
    residual = layers.Conv2D(1024, (1, 1), strides=(2, 2),
                             padding='same', use_bias=False)(x)
    residual = layers.BatchNormalization(axis=channel_axis)(residual)

    x = layers.Activation('relu', name='block13_sepconv1_act')(x)
    x = layers.SeparableConv2D(728, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block13_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block13_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block13_sepconv2_act')(x)
    x = layers.SeparableConv2D(1024, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block13_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block13_sepconv2_bn')(x)

    x = layers.MaxPooling2D((3, 3),
                            strides=(2, 2),
                            padding='same',
                            name='block13_pool')(x)
    x = Dropout(0.1)(x)
    x = layers.add([x, residual])

    # Block 14: two separable convolutions (1536, 2048 filters) without shortcut.
    x = layers.SeparableConv2D(1536, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block14_sepconv1')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block14_sepconv1_bn')(x)
    x = layers.Activation('relu', name='block14_sepconv1_act')(x)

    x = layers.SeparableConv2D(2048, (3, 3),
                               padding='same',
                               use_bias=False,
                               name='block14_sepconv2')(x)
    x = layers.BatchNormalization(axis=channel_axis, name='block14_sepconv2_bn')(x)
    x = layers.Activation('relu', name='block14_sepconv2_act')(x)
    # x=cbam(x)
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    # x = layers.Dense(classes, activation='softmax', name='predictions')(x)


    # Age head: 512 -> 256 -> dropout -> 4-way softmax.
    x1 = tf.keras.layers.Dense(512, activation='relu')(x)
    x1 = tf.keras.layers.Dense(256, activation='relu')(x1)
    out_age = tf.keras.layers.Dense(4, activation='softmax', name='out_age')(Dropout(0.5)(x1))

    # Gender head: 512 -> dropout -> 2-way softmax.
    x2 = tf.keras.layers.Dense(512, activation='relu')(x)
    out_gender = tf.keras.layers.Dense(2, activation='softmax', name='out_gender')(Dropout(0.5)(x2))

    # Output order is gender first, then age.
    predictions = [out_gender,out_age]

    model = tf.keras.models.Model(inputs=img_input, outputs=predictions,name='xception-multi-label')
    return model

# Rebuild the network and load the trained weights from the checkpoint.
model = Xception()
model.summary()
model.load_weights("/data_1/Yang/project_new/2020/tf_study/tf_xception/multi_xception/model/multi-label-model_005.h5")

# Label name -> class index for each head.
dict_gender = {'f':0,
               'm':1
}

dict_age = {'children':0,
            'young':1,
            'adult':2,
            'older':3
}

# Combined "<gender>_<age>" label -> code written to the result file.
dict_label_tijiao = {"f_children":"0",
              "f_young":"1",
              "f_adult":"2",
              "f_older":"3",
              "m_children":"4",
              "m_young":"5",
              "m_adult":"6",
              "m_older":"7"
              }

# "<gender index><age index>" -> combined label.
map_predict_code2label = {"00":"f_children",
                          "01":"f_young",
                          "02":"f_adult",
                          "03":"f_older",
                          "10":"m_children",
                          "11":"m_young",
                          "12":"m_adult",
                          "13":"m_older"}

root_dir_test =  "/data_2/big-data/compete/20200323/src_data/test-tijiao/"

# One "<image name> <code>" line per image goes to result.txt inside the
# test directory.
with open(os.path.join(root_dir_test, 'result.txt'), 'w') as fw:
    for root, dirs, files in os.walk(root_dir_test):
        for img_name_ in files:
            if not img_name_.endswith((".jpg", ".jpeg", ".png")):
                continue

            # splitext drops only the final extension, so names that
            # contain extra dots ("a.b.jpg" -> "a.b") are no longer cut at
            # the first dot.
            name = os.path.splitext(img_name_)[0]
            img_path = os.path.join(root, img_name_)

            # Same preprocessing as the batch generator: resize, scale by 1/255.
            img = image.load_img(img_path, target_size=(SIZE, SIZE))
            img = image.img_to_array(img) / 255.0
            img = np.expand_dims(img, axis=0)  # add the batch dimension

            # predictions[0] is the gender head, predictions[1] the age head.
            predictions = model.predict(img)
            label_gender = np.argmax(predictions[0], axis=1)
            label_age = np.argmax(predictions[1], axis=1)

            code = str(label_gender[0]) + str(label_age[0])
            label_describe = map_predict_code2label[code]
            label_tijiao_val = dict_label_tijiao[label_describe]
            fw.write(name + " " + label_tijiao_val + '\n')

            # cv2.imshow("img", img[0][:, :, ::-1])
            # cv2.waitKey(0)