众所周知,AE是一个做特效的软件,它可以很方便地跟踪视频中的点位,甚至做三维跟踪。

又众所周知,很多时候练模型需要标注视频。

邪修数据标注···

任务目标

  • 我需要训练一个能够检测一个沙盘四个角点用于矫正的模型。

于是,

  • 我需要拍摄一段视频,标注好四个角点。

数据标注

AE的跟踪器可以很方便地跟踪点。

将一个或多个视频放入一个合成,每一个点作为一个蒙版。

(上图在同一个合成中的两段视频都是训练集)

点存在就有关键帧,点不存在的位置没有关键帧,方便最后导出。

模型训练需要验证集,需要按照上述步骤再开一个新的合成作为验证集。

因为视频每帧之间太过相近了,所以不能从训练集中抽一些作为验证集,需要单独拍一段,单独标注。

数据导出

众所周知,Adobe的软件和软件用的插件都是JS写的,而且AE是支持直接运行JS脚本的。

所以我们可以写(AI)一段JS脚本,用于导出数据。

// 将所有层的相同索引蒙版合并导出为 {index}.csv
// 每个文件包含: frame, x, y
// 上层覆盖下层,导出时只包含关键帧点

(function () {
    app.beginUndoGroup("Export Mask First Point Keyframes By Index");

    var comp = app.project.activeItem;
    if (!(comp && comp instanceof CompItem)) {
        alert("请先打开一个合成喵~");
        return;
    }

    // 选择保存文件夹
    var folder = Folder.selectDialog("请选择保存文件夹喵~");
    if (!folder) {
        return;
    }

    var fps = comp.frameRate;
    var maskData = {}; // { index: { frame: {x, y, layerIndex} } }

    // 遍历所有图层,从下往上(后面覆盖前面)
    for (var li = 1; li <= comp.numLayers; li++) {
        var layer = comp.layer(li);
        if (!layer.mask || layer.mask.numProperties < 1) continue;

        for (var mi = 1; mi <= layer.mask.numProperties; mi++) {
            var mask = layer.mask(mi);
            var maskPath = mask.property("maskPath");
            if (!maskPath || maskPath.numKeys === 0) continue;

            if (!maskData[mi]) maskData[mi] = {};

            for (var k = 1; k <= maskPath.numKeys; k++) {
                var t = maskPath.keyTime(k);
                var shape = maskPath.keyValue(k);
                if (shape.vertices.length === 0) continue;

                var vertex = shape.vertices[0];
                var frame = Math.round(t * fps);

                // 计算合成坐标(加上图层的 transform 偏移)
                var pos = [vertex[0], vertex[1]];
                var transform = layer.property("ADBE Transform Group");
                if (transform) {
                    var anchor = transform.property("ADBE Anchor Point").value;
                    var position = transform.property("ADBE Position").value;
                    pos[0] += (position[0] - anchor[0]);
                    pos[1] += (position[1] - anchor[1]);
                }

                // 如果该帧已存在,则由上层覆盖
                maskData[mi][frame] = {
                    x: pos[0],
                    y: pos[1],
                    layerIndex: li
                };
            }
        }
    }

    // 写入 CSV 文件
    for (var index in maskData) {
        var file = new File(folder.fsName + "/" + index + ".csv");
        file.encoding = "UTF-8";
        file.open("w");
        file.writeln("frame,x,y");

        var frames = [];
        for (var f in maskData[index]) frames.push(parseInt(f));
        frames.sort(function (a, b) { return a - b; });

        for (var i = 0; i < frames.length; i++) {
            var frame = frames[i];
            var d = maskData[index][frame];
            file.writeln(frame + "," + d.x.toFixed(3) + "," + d.y.toFixed(3));
        }

        file.close();
    }

    alert("导出完成喵~\n文件已保存到:" + folder.fsName);
    app.endUndoGroup();
})();

于是得到这样子的数据:

frame,x,y
0,1299.500,892.188
1,1299.500,892.188
2,1309.252,891.400
3,1319.506,890.490
4,1331.185,889.569
5,1343.930,888.683
6,1357.215,887.278
7,1371.780,886.319
...

数据集生成

根据Ultralytics的数据集说明

从AE中,将标注用的合成直接导出成PNG序列。

训练集放入 dataset/images/train

验证集放入 dataset/images/val

使用下面的脚本,根据前面导出的数据生成标注标签。

"""
数据集生成脚本
根据AE导出的CSV标注文件,生成YOLO格式的标签文件并创建数据集配置文件data.yaml
"""

import tkinter as tk
from tkinter import filedialog
import os
import pandas as pd
import glob
from PIL import Image
import yaml


def select_folder(title):
    """弹出文件夹选择对话框,返回选择的文件夹路径"""
    root = tk.Tk()
    root.withdraw()

    folder_path = filedialog.askdirectory(
        title=title,
        initialdir=os.getcwd()
    )

    if folder_path:
        print(f"选择的文件夹: {folder_path}")
        return folder_path
    else:
        print("未选择文件夹")
        return None


def read_and_merge_csv_files(folder_path):
    """读取文件夹内所有CSV文件,合并成一个数组"""
    if not os.path.exists(folder_path):
        print(f"文件夹不存在: {folder_path}")
        return []

    csv_files = glob.glob(os.path.join(folder_path, "*.csv"))
    if not csv_files:
        print(f"文件夹内没有CSV文件: {folder_path}")
        return []

    merged_data = []

    for csv_file in csv_files:
        try:
            df = pd.read_csv(csv_file)

            # 检查列名是否正确
            required_columns = ['frame', 'x', 'y']
            if not all(col in df.columns for col in required_columns):
                print(f"跳过文件 {csv_file}: 列名不符合要求")
                continue

            # 检查数据类型
            if not pd.api.types.is_integer_dtype(df['frame']):
                print(f"跳过文件 {csv_file}: frame列不是整数类型")
                continue

            if not pd.api.types.is_numeric_dtype(df['x']) or not pd.api.types.is_numeric_dtype(df['y']):
                print(f"跳过文件 {csv_file}: x或y列不是数值类型")
                continue

            # 只保留需要的列
            data = df[required_columns].values.tolist()
            merged_data.extend(data)
            print(f"成功读取文件 {csv_file}: {len(data)} 条记录")

        except Exception as e:
            print(f"读取文件 {csv_file} 时出错: {e}")
            continue
        # 类型转换
        for row in data:
            row[0] = int(row[0])  # frame列转换为int
            row[1] = float(row[1])  # x列转换为float
            row[2] = float(row[2])  # y列转换为float

    print(f"总共合并了 {len(merged_data)} 条记录")
    return merged_data


def list_files_without_extension(folder_path):
    """列出文件夹中所有文件(非递归),去除扩展名、排序去重后返回"""
    if not os.path.exists(folder_path):
        print(f"文件夹不存在: {folder_path}")
        return []

    file_names = []

    for item in os.listdir(folder_path):
        item_path = os.path.join(folder_path, item)
        if os.path.isfile(item_path):
            # 去除扩展名
            name_without_ext = os.path.splitext(item)[0]
            file_names.append(name_without_ext)

    # 排序去重
    unique_sorted_names = sorted(list(set(file_names)))
    return unique_sorted_names


def read_images_size(image_folder):
    """读取图像文件夹中第一张图像的尺寸,返回宽高元组"""

    if not os.path.exists(image_folder):
        print(f"图像文件夹不存在: {image_folder}")
        return (0, 0)

    image_files = glob.glob(os.path.join(image_folder, "*"))
    if not image_files:
        print(f"图像文件夹内没有图像文件: {image_folder}")
        return (0, 0)

    with Image.open(image_files[0]) as img:
        width, height = img.size
        print(f"图像尺寸: 宽={width}, 高={height}")
        return (width, height)


def create_label_files(dataset_path, target_type, label_list, filename_list, image_size=(0, 0)):
    dir_name = os.path.join(dataset_path, "labels", target_type)
    os.makedirs(dir_name, exist_ok=True)
    # 清空目录下的所有文件
    for filename in os.listdir(dir_name):
        file_path = os.path.join(dir_name, filename)
        if os.path.isfile(file_path):
            os.remove(file_path)
    # 创建标签文件
    for label in label_list:
        frame, x, y = label
        image_name = filename_list[frame]
        label_file_path = os.path.join(dir_name, f"{image_name}.txt")
        # 将像素坐标转换为YOLO格式的相对坐标
        x_center = x / image_size[0]
        y_center = y / image_size[1]
        # YOLO格式: class_id x_center y_center width height
        with open(label_file_path, "a") as f:
            f.write(f"0 {x_center:.6f} {y_center:.6f} 0.001 0.001\n")


def gen_config(dataset_path):
    config = {
        'train': 'images/train',
        'val': 'images/val',
        'nc': 1,
        'names': ['object']
    }
    config_path = os.path.join(dataset_path, 'data.yaml')
    with open(config_path, 'w') as f:
        yaml.dump(config, f)
    print(f"配置文件已生成: {config_path}")


if __name__ == "__main__":
    dataset_path = select_folder("选择数据集文件夹")
    if dataset_path is None:
        exit(1)
    train_csv_dir = select_folder("选择训练集标注CSV文件夹")
    if train_csv_dir is None:
        exit(1)
    val_csv_dir = select_folder("选择验证集标注CSV文件夹")
    if val_csv_dir is None:
        exit(1)

    train_data = read_and_merge_csv_files(train_csv_dir)
    val_data = read_and_merge_csv_files(val_csv_dir)

    train_files = list_files_without_extension(
        os.path.join(dataset_path, "images", "train"))
    val_files = list_files_without_extension(
        os.path.join(dataset_path, "images", "val"))
    print(f"训练集文件数: {len(train_files)}")
    print(f"验证集文件数: {len(val_files)}")

    if (max(train_data, key=lambda x: x[0])[0] >= len(train_files)):
        print("训练集标注数据帧数多于训练集文件数")
        exit(1)
    if (max(val_data, key=lambda x: x[0])[0] >= len(val_files)):
        print("验证集标注数据帧数多于验证集文件数")
        exit(1)

    image_size = read_images_size(
        os.path.join(dataset_path, "images", "train")
    )

    print("数据集准备完成")

    create_label_files(dataset_path, "train", train_data,
                       train_files, image_size)
    create_label_files(dataset_path, "val", val_data, val_files, image_size)
    gen_config(dataset_path)

    print("标签文件生成完成")

经过上面的脚本处理,能够生成符合YOLO格式的数据集。

训练

于是就能愉快开始训练了。

使用下面简单的代码就能开始训练

if __name__ == "__main__":
    from ultralytics import YOLO
    model = YOLO("yolo11s.pt")
    results = model.train(data="dataset/data.yaml", epochs=100,
                          imgsz=640, batch=0.9, project="training_results", exist_ok=True)

(此处,在Windows必须要有 if __name__ == "__main__" 的判断,因为多进程需要)

测试

使用上面的代码训练的话,训练出来的东西位于 training_results/train/weights/best.pt

使用下面的代码,直接调用摄像头进行测试:

from ultralytics import YOLO
import cv2

model = YOLO('training_results/train/weights/best.pt')
cap = cv2.VideoCapture(0)

while True:
    ret, frame = cap.read()
    if not ret:
        break

    results = model(frame)
    annotated_frame = results[0].plot()
    cv2.imshow('YOLO Detection', annotated_frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

最终效果

我能想到的,最大的成功就是无愧于自己的心。