百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

使用ImageNet在faster-rcnn上训练自己的分类器

toyiye 2024-06-30 10:04 21 浏览 0 评论

这是我对cup, glasses训练的识别

faster-rcnn在fast-rcnn的基础上加了rpn来将整个训练都置于GPU内,以用来提高效率,这里我们将使用ImageNet的数据集来在faster-rcnn上来训练自己的分类器。从ImageNet上可下载到很多类别的Image与bounding box annotation来进行训练(每一个类别下的annotation都少于等于image的个数,所以我们从annotation来建立索引)。

lib/dataset/factory.py中提供了coco与voc的数据集获取方法,而我们要做的就是在这里加上我们自己的ImageNet获取方法,我们先来建立ImageNet数据获取主文件。coco与pascal_voc的获取都是继承于父类imdb,所以我们可根据pascal_voc的获取方法来做模板修改完成我们的ImageNet类。

创建ImageNet类

由于在faster-rcnn里使用rpn来代替了selective_search,所以我们可以在使用时直接略过有关selective_search的方法,根据pascal_voc类做模板,我们需要留下的方法有:

__init__ //初始化
image_path_at //根据数据集列表的index来取图片绝对地址
image_path_from_index //配合上面
_load_image_set_index //获取数据集列表
_gt_roidb //获取ground-truth数据
rpn_roidb //获取region proposal数据
_load_rpn_roidb //根据gt_roidb生成rpn_roidb数据并合成
_load_psacal_annotation //加载annotation文件并对bounding box进行数据整理

__init__:

def __init__(self, image_set):
        imdb.__init__(self, 'imagenet')
        self._image_set = image_set
        self._data_path = os.path.join(cfg.DATA_DIR, "imagenet")
        #类别与对应的wnid,可以修改成自己要训练的类别
        self._class_wnids = {
 'cup': 'n03147509',
 'glasses': 'n04272054'
        }

        #类别,修改类别时同时要修改这里
        self._classes = ('__background__', self._class_wnids['cup'], self._class_wnids['glasses'])
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        #bounding box annotation 文件的目录
        self._xml_path = os.path.join(self._data_path, "Annotations")
        self._image_ext = '.JPEG'
        #我们使用xml文件名来做数据集的索引
        # the xml file name and each one corresponding to image file name
        self._image_index = self._load_xml_filenames
        self._salt = str(uuid.uuid4)
        self._comp_id = 'comp4'

        self.config = {'cleanup'     : True,
 'use_salt'    : True,
 'use_diff'    : False,
 'matlab_eval' : False,
 'rpn_file'    : None,
 'min_size'    : 2}

        assert os.path.exists(self._data_path), \
 'Path does not exist: {}'.format(self._data_path)

image_path_at

def image_path_at(self, i):
        #使用index来从xml_filenames取到filename,生成绝对路径
        return self.image_path_from_image_filename(self._image_index[i])

image_path_from_image_filename(类似pascal_voc中的image_path_from_index)

def image_path_from_image_filename(self, image_filename):
        image_path = os.path.join(self._data_path, 'Images',
 image_filename + self._image_ext)
        assert os.path.exists(image_path), \
 'Path does not exist: {}'.format(image_path)
        return image_path

_load_xml_filenames(类似pascal_voc中的_load_image_set_index)

def _load_xml_filenames(self):
        #从Annotations文件夹中拿取到bounding box annotation文件名
        #用来做数据集的索引
        xml_folder_path = os.path.join(self._data_path, "Annotations")
        assert os.path.exists(xml_folder_path), \
 'Path does not exist: {}'.format(xml_folder_path)

        for dirpath, dirnames, filenames in os.walk(xml_folder_path):
 xml_filenames = [xml_filename.split(".")[0] for xml_filename in filenames]

        return xml_filenames

gt_roidb

def gt_roidb(self):
        #Ground-Truth 数据缓存
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
 with open(cache_file, 'rb') as fid:
 roidb = cPickle.load(fid)
 print '{} gt roidb loaded from {}'.format(self.name, cache_file)
 return roidb

        #从xml中获取Ground-Truth数据
        gt_roidb = [self._load_imagenet_annotation(xml_filename)
 for xml_filename in self._image_index]
        with open(cache_file, 'wb') as fid:
 cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote gt roidb to {}'.format(cache_file)

        return gt_roidb

rpn_roidb

def rpn_roidb(self):
        #根据gt_roidb生成rpn_roidb,并进行合并 
        gt_roidb = self.gt_roidb
        rpn_roidb = self._load_rpn_roidb(gt_roidb)
        roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)

        return roidb

_load_rpn_roidb

def _load_rpn_roidb(self, gt_roidb):
        filename = self.config['rpn_file']
        print 'loading {}'.format(filename)
        assert os.path.exists(filename), \
 'rpn data not found at: {}'.format(filename)
        with open(filename, 'rb') as f:
 box_list = cPickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

_load_imagenet_annotation(类似于pascal_voc中的_load_pascal_annotation)

def _load_imagenet_annotation(self, xml_filename):
        #从annotation的xml文件中拿取bounding box数据
        filepath = os.path.join(self._data_path, 'Annotations', xml_filename + '.xml')
        #这里使用了ap,是我写的一个annotation parser,在后面贴出代码
        #它会返回这个xml文件的wnid, 图像文件名,以及里面包含的注解物体
        wnid, image_name, objects = ap.parse(filepath)
        num_objs = len(objects)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objects):
 box = obj["box"]
 x1 = box['xmin']
 y1 = box['ymin']
 x2 = box['xmax']
 y2 = box['ymax']
 # 如果这个bounding box并不是我们想要学习的类别,那则跳过
 # go next if the wnid not exist in declared classes
 try:
 cls = self._class_to_ind[obj["wnid"]]
 except KeyError:
 print "wnid %s isn't show in given"%obj["wnid"]
 continue
 boxes[ix, :] = [x1, y1, x2, y2]
 gt_classes[ix] = cls
 overlaps[ix, cls] = 1.0
 seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes' : boxes,
 'gt_classes': gt_classes,
 'gt_overlaps' : overlaps,
 'flipped' : False,
 'seg_areas' : seg_areas}

annotation_parser.py文件

import os
import xml.dom.minidom

def getText(node):
    return node.firstChild.nodeValue

def getWnid(node):
    return getText(node.getElementsByTagName("name")[0])

def getImageName(node):
    return getText(node.getElementsByTagName("filename")[0])

def getObjects(node):
    objects = 
    for obj in node.getElementsByTagName("object"):
        objects.append({
 "wnid": getText(obj.getElementsByTagName("name")[0]),
 "box":{
 "xmin": int(getText(obj.getElementsByTagName("xmin")[0])),
 "ymin": int(getText(obj.getElementsByTagName("ymin")[0])),
 "xmax": int(getText(obj.getElementsByTagName("xmax")[0])),
 "ymax": int(getText(obj.getElementsByTagName("ymax")[0])),
 }
        })
    return objects

def parse(filepath):
    dom = xml.dom.minidom.parse(filepath)
    root = dom.documentElement
    image_name = getImageName(root)
    wnid = getWnid(root)
    objects = getObjects(root)
    
    return wnid, image_name, objects

则对数据结构的要求是:

|---data
  |---imagenet
    |---Annotations
       |---n03147509
 |---n03147509_*.xml
 |---...
       |---n04272054
 |---n04272054_*.xml
 |---...
    |---Images
       |---n03147508_*.JPEG
       |---...
       |---n04272054_*.JPEG
       |---...

同时我在github上也提供了draw方法,可以用来将bounding box画于Image文件上,用来甄别该annotation的正确性

训练

这样,我们的ImageNet类则是生成好了,下面我们则可以训练我们的数据,但是在开始之前,还有一件事情,那就是修改prototxt中的与类别数目有关的值,我将models/pascal_voc拷贝到了models/imagenet进行修改,比如我想要训练ZF,如果使用的是train_faster_rcnn_alt_opt.py,则需要修改models/imagenet/ZF/faster_rcnn_alt_opt/下的所有pt文件里的内容,用如下的法则去替换:

//num为类别的个数
input-data->num_classes = num
class_score->num_output = num
bbox_pred->num_output   = num*4

我这里使用train_faster_rcnn_alt_opt.py进行的训练,这样的话则需要把添加的models/imagenet作为可选项

//pt_type 则是添加的选择项,默认使用psacal_voc的models
./tools/train_faster_rcnn_alt_opt.py --gpu 0 \
--net_name ZF \
--weights data/imagenet_models/ZF.v2.caffemodel[optional] \
--imdb imagenet \
--cfg experiments/cfgs/faster_rcnn_alt_opt.yml \
--pt_type imagenet

识别

这里我们则需要使用刚训练出来的模型进行识别

#就像demo.py一样,但是使用训练的models,我创建了tools/classify.py来单独识别
prototxt = os.path.join(cfg.ROOT_DIR, 'models/imagenet', NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
caffemodel = os.path.join(cfg.ROOT_DIR, 'output/faster_rcnn_alt_opt/imagenet/'+ NETS[args.demo_net][0] +'_faster_rcnn_final.caffemodel')

同样,在识别前我们要对识别方法里的Classes进行修改,修改成你自己训练的类别后

执行

./tools/classify.py --net zf

则可对data/demo下的图片文件使用训练的zf网络进行识别

Have fun

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码