加载中…
个人资料
  • 博客等级:
  • 博客积分:
  • 博客访问:
  • 关注人气:
  • 获赠金笔:0支
  • 赠出金笔:0支
  • 荣誉徽章:
正文 字体大小:

centerpose进行关键点检测

(2021-08-26 15:06:03)
分类: 机器学习
本篇介绍图片中关键点检测 以及 视频中关键点检测的两种方法。

1、centerpose是基于centernet开发的,
centernet的github地址是:https://github.com/xingyizhou/CenterNet
在Third-party resources中有如下截图:
centerpose进行关键点检测
点击这个连接,就可以进入centerpose的github

2、centerpose安装说明
(1)安装pytorch和torchvision
torch需要1.1版本的,
torchvision需要0.3.0以上(必须的,版本过低会报错)

(2)安装依赖的python包
pip install -r requirements.txt

(3)编译DCNv2
cd $CenterNet_ROOT/src/lib/models/networks/DCNv2(必须在linux环境进行编译,我试了windows各种坑还是没搞定)
./make.sh

(4)编译nms
cd $CenterNet_ROOT/src/lib/external
make

3、centerpose关键点个数修改训练
以下说明得到的效果不是很好,最后得到的关键点类似于bbox,没有完全拟合标注的关键点。

coco默认的关键点有17个,见:https://zhuanlan.zhihu.com/p/121452714,
https://blog.csdn.net/u011291667/article/details/84329990
centerpose进行关键点检测

在标注时,需要严格按照相应顺序进行标注,不能乱掉
coco的json文件格式如下,其中segmentation、num_keypoints、keypoints是必备的信息。
centerpose进行关键点检测

annotation的一个完整例子如下:
{"category_id": 1, "id": 17, "image_id": 12, "iscrowd": 0, "num_keypoints": 4,"segmentation":[[122.76, 439.67, 122.3, 523.82, 100.23, 523.82, 101.6, 439.7]], "keypoints": [122.76, 439.67, 2, 122.3, 523.82, 2, 100.23, 523.82, 2, 101.6, 439.7, 2], "area": 1839.0, "bbox": [100.0, 440.0, 23.0, 84.0]}

-------------------------------修改关键点标注个数的处理方式-------------------------------------
假如我们希望关键点的个数是4个,而不是17个,怎么处理?
(1)所有代码中17替换成4个
(2)部分五十几的(比如59)替换成十几(比如13),下面是coco_hp.py文件
centerpose进行关键点检测
dcts的前5位都是bbox相关的,第5位之后才能keypoints的数据,总共13位。

(3)coco_hp.py的run_eval的coco_eval = COCOeval_r(self.coco, coco_dets, "keypoints")替换为coco_eval = COCOeval_r(self.coco, coco_dets, "bbox")。
因为依赖的底层文件cocoeval.py的setKpParams的kpt_oks_sigmas是固定17位的,所以使用keypoints是绕不开这个文件的。
涉及逻辑就是:如果计算得到的iou的效果更好,则更新为best_model。替换之后,就根据bbox的效果评估是否要替换。

(4)detectors/multi_pose.py
最后一个函数show_results()中,debugger.show_all_imgs(pause=self.pause) 修改为: debugger.save_all_imgs(path='/workspace/hugh/centerpose-master/output')。
img_id都修改为"multi_pose",如下图所示:
centerpose进行关键点检测


(5)utils/debugger.py
对debugger.add_coco_hp函数进行修改,通过cv2.line让4个keypoints连接起来,并对heatmap的中心点画一个小圆圈。
centerpose进行关键点检测


4、训练相关命令:
训练:
python -m torch.distributed.launch train.py --cfg ../experiments/dla_34_512x512.yaml
测试:
python demo.py --cfg /workspace/hugh/centerpose-master/experiments/dla_34_512x512.yaml --TESTMODEL /workspace/hugh/data/centerpose_output/dla/model_best.pth --DEMOFILE /workspace/hugh/data/centerpose_coco/images/16.jpg --DEBUG 1

从实践来看,dla_34的算法比mobilenetv3的效果更好。

5、dla_34_512x512.yaml配置参考

SAMPLE_METHOD: 'coco_hp'
DATA_DIR: '/workspace/hugh/data'
EXP_ID: 'coco_pose_dla'
DEBUG: 0
DEBUG_THEME: 'white'
SEED: 317
OUTPUT_DIR: '/workspace/hugh/data/centerpose_output/dla'
LOG_DIR: ''
EXPERIMENT_NAME: ''
GPUS: [1]
WORKERS: 1
PRINT_FREQ: 0
PIN_MEMORY: true
RANK: 0
SAVE_RESULTS: true

CUDNN:
  BENCHMARK: true

MODEL:
  INIT_WEIGHTS: false
  PRETRAINED: ''
  CENTER_THRESH: 0.1
  NUM_CLASSES: 1
  NAME: 'dla_34'
  HEADS_NAME: 'keypoint'
  HEADS_NUM: [1, 2, 8, 2, 4, 2]
  HEAD_CONV: 256
  INTERMEDIATE_CHANNEL: 64  
  DOWN_RATIO: 4
  NUM_STACKS: 1
  INPUT_RES: 512
  OUTPUT_RES: 128
  INPUT_H: 512
  INPUT_W: 512
  PAD: 31
  NUM_KEYPOINTS: 17
  TAG_PER_JOINT: true
  TARGET_TYPE: 'gaussian'
  SIGMA: 2    

LOSS:
  METRIC: 'loss'
  MSE_LOSS: false
  REG_LOSS: 'l1'
  USE_OHKM: false
  TOPK: 8
  USE_TARGET_WEIGHT: true
  USE_DIFFERENT_JOINTS_WEIGHT: false
  HP_WEIGHT: 1.
  HM_HP_WEIGHT: 1.
  DENSE_HP: false
  HM_HP: true
  REG_BBOX: true
  WH_WEIGHT: 0.1
  REG_OFFSET: true  
  OFF_WEIGHT: 1.
  REG_HP_OFFSET: true
  HM_HP_WEIGHT: 1.
  
DATASET:
  DATASET: 'coco_hp'
  TRAIN_SET: 'train'
  TEST_SET: 'valid'
  TRAIN_IMAGE_DIR: 'images'
  TRAIN_ANNOTATIONS: ['train.json']
  VAL_IMAGE_DIR: 'images'
  VAL_ANNOTATIONS: 'train.json'

  # training data augmentation
  MEAN: [0.408, 0.447, 0.470]
  STD: [0.289, 0.274, 0.278]
  SHIFT: 0.1
  SCALE: 0.4
  ROTATE: 0.
  # for pose
  AUG_ROT: 0.
  FLIP: 0.5
  NO_COLOR_AUG: false

  ROT_FACTOR: 30
  SCALE_MIN: 0.5
  SCALE_MAX: 1.1
  IMAGE_SIZE: 512
  RANDOM_CROP: true 
  
TRAIN:
  DISTRIBUTE: true
  OPTIMIZER: 'adam'
  LOCAL_RANK: 0
  HIDE_DATA_TIME: false 
  SAVE_ALL_MODEL: false
  RESUME: false
  LR_FACTOR: 0.1
  LR_STEP: [270, 300]
  EPOCHS: 32000
  NUM_ITERS: -1
  LR: 2.8125e-3
  BATCH_SIZE: 30
  MASTER_BATCH_SIZE: 30

  MOMENTUM: 0.9
  WD: 0.0001
  NESTEROV: false
  GAMMA1: 0.99
  GAMMA2: 0.0

  # 'apply and reset gradients every n batches'
  STRIDE_APPLY: 1
  CHECKPOINT: ''
  SHUFFLE: true
  VAL_INTERVALS: 1
  TRAINVAL: false
 
TEST:
  # Test Model Epoch
  MODEL_PATH: '/workspace/hugh/data/centerpose_output/dla/model_best.pth'
  TASK: 'multi_pose'
  FLIP_TEST: true

  DEMO_FILE: '../images/16.jpg'
  MODEL_FILE: ''
  TEST_SCALES: [1]
  IMAGE_THRE: 0.1
  TOPK: 100
  NMS: true
  NMS_THRE: 0.5
  NOT_PREFETCH_TEST: false
  FIX_RES: false

  SOFT_NMS: false
  OKS_THRE: 0.5
  VIS_THRESH: 0.3
  KEYPOINT_THRESH: 0.2
  NUM_MIN_KPT: 4
  THRESH_HUMAN: 0.4

  EVAL_ORACLE_HM: false
  EVAL_ORACLE_WH: false
  EVAL_ORACLE_OFFSET: false
  EVAL_ORACLE_KPS: false
  EVAL_ORACLE_HMHP: false
  EVAL_ORACLE_HP_OFFSET: false
  EVAL_ORACLE_DEP: false


--------------------------视频中进行关键点检测------------------------------------------
1、视频中进行关键点检测的流程:
(1)读取视频流中每一帧图片
(2)对图片识别目标对象
(3)将识别后的图像重新输出到视频流中

接下来的修改是基于上面的图片中关键点检测  做进一步修改
2、detectors/multi_pose.py
show_results函数,注释掉"debugger.save_all_imgs(path='/workspace/hugh/centerpose-master/output')"

3、detectors/base_detector.py
修改run()中的函数如下,增加multi_pose相关信息的输出,有目标信息的图像就是放在debugger.imgs['multi_pose']中。
centerpose进行关键点检测

4、tools/demo.py修改
增加视频流读取后的处理操作,并输出到视频流中。
同时创建好results目录,如下面代码所示的路径。
centerpose进行关键点检测



5、视频流进行关键点检测的执行命令:
python demo.py --cfg /workspace/hugh/centerpose-master/experiments/dla_34_512x512.yaml --TESTMODEL /workspace/hugh/data/centerpose_output/dla/model_best.pth --DEMOFILE /workspace/hugh/data/video/test_track.avi --DEBUG 1

0

阅读 收藏 喜欢 打印举报/Report
  

新浪BLOG意见反馈留言板 欢迎批评指正

新浪简介 | About Sina | 广告服务 | 联系我们 | 招聘信息 | 网站律师 | SINA English | 产品答疑

新浪公司 版权所有