Skip to content

The-Pillars-of-Creation/NTO-second-tour

Repository files navigation

Vitya model

Наша нейросеть Витя умеет различать видео и классифицировать их на 9 категорий

label id
water 0
car 1
cloud 2
food 3
flower 4
dance 5
animal 6
sunset 7
fire 8

Мы использовали ResNet50V2 для покадравой классификации видео и давали тэг по наиболее частому тэгу кадров.

class VityaModel:
    def __init__(self) -> None:
        augmentaion_layer = Sequential([
            RandomFlip("horizontal", seed=RANDOM_SEED),
            RandomRotation(AUGMENTATION_FACTOR, seed=RANDOM_SEED),
            RandomZoom(AUGMENTATION_FACTOR, seed=RANDOM_SEED),
            RandomHeight(AUGMENTATION_FACTOR, seed=RANDOM_SEED),
            RandomWidth(AUGMENTATION_FACTOR, seed=RANDOM_SEED),
            Rescaling(1 / 255.)
        ])

        augmentaion_layer
        base_model = tf.keras.applications.ResNet50V2(include_top=False, weights=None)

        base_model.trainable = False

        input_layer = tf.keras.layers.Input(shape=(224, 224, 3), name="input_layer")
        x = augmentaion_layer(input_layer)
        x = base_model(x, training=False)
        x = tf.keras.layers.GlobalAveragePooling2D(name="global_average_pooling2d")(x)
        output_layer = Dense(len(TAGS), activation=softmax, name="output_layer")(x)

        model = tf.keras.Model(input_layer, output_layer)
        model.compile(
            loss=tf.keras.losses.CategoricalCrossentropy(),
            optimizer=tf.keras.optimizers.Adam(),
            metrics=["accuracy"]
        )
        
        model_checkpoint_path = "./checkpoints/vitya_weights"
        model.load_weights(model_checkpoint_path)
        print("Compiled VityaModel")

        self.model = model

Обучение модели train.ipynb

history_1 = model_1.fit(
    train_data,
    epochs=EPOCHS,
    steps_per_epoch=len(train_data),
    validation_data=test_data,
    validation_steps=int(0.15 * len(test_data)),
    callbacks=[
        tf.keras.callbacks.CSVLogger("history/history.csv"),
        checkpoint_callback,
        create_tensorboard_callback(dir_name="tensorboard", experiment_name="vitya")
    ]
)
Файл Назначение
solution.py итоговое решение для отправки в систему НТО
solution.ipynb тестировка модели на тестовых фотографиях
train.ipynb обучение модели
first_try.ipynb первая попытка обучения модели
dataload.py первоначальная загрузка данных для pytorch, не используется