实现类:CenterPoint/det3d/core/sampler/sample_ops.py: DataBaseSamplerV2
build_dbsampler参数:
type="GT-AUG",enable=False,
db_info_path="data/nuScenes/dbinfos_train_10sweeps_withvelo.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
dict(construction_vehicle=7),
dict(bus=4),
dict(trailer=6),
dict(barrier=2),
dict(motorcycle=6),
dict(bicycle=6),
dict(pedestrian=2),
dict(traffic_cone=2),
],
db_prep_steps=[
dict(
filter_by_min_num_points=dict(
car=5,
truck=5,
bus=5,
trailer=5,
construction_vehicle=5,
traffic_cone=5,
barrier=5,
motorcycle=5,
bicycle=5,
pedestrian=5,
)
),
dict(filter_by_difficulty=[-1],),
],
global_random_rotation_range_per_object=[0, 0],
rate=1.0,
对应DataBaseSamplerV2构造函数
db_infos = pickle.load(cfg.db_info_path)
groups = cfg.sample_groups
db_prepor = [build_db_preprocess(c, logger=logger) for c in cfg.db_prep_steps]
rate = cfg.rate
grot_range = list(grot_range)
self._group_db_infos = db_infos
self._rate = rate
self._groups = groups
self._sample_classes = ['car', 'truck',...,'traffic_cone']
self._sample_max_nums = [2, 3,...,2]
for k, v in self._group_db_infos.items():
self._sampler_dict[k] = prep.BatchSampler(v, k) # 此处v为单个类别的对象的db数据,k为类名,BatchSampler为一个用于批量Shuffle采样的类
self._global_rot_range = global_rot_range
self._enable_global_rot = True if np.abs(global_rot_range[0] - global_rot_range[1]) >= 1e-3 # 所以没有允许这个,为啥?
具体采样实现为sample_all,实际用到的参数有root_path, gt_boxes, gt_names, num_point_features
Pipeline
1st: 确定每个cls的采样数目:
for class_name, max_sample_num in zip(
self._sample_classes, self._sample_max_nums
):
sampled_num = int(
max_sample_num - np.sum([n == class_name for n in gt_names])
)
sampled_num = np.round(self._rate * sampled_num).astype(np.int64)
sampled_num_dict[class_name] = sampled_num
sample_num_per_class.append(sampled_num)
2nd: 根据确定的采样数目进行采样,调用sample_class_v2
sampled记录采样信息,sampled_gt_boxes记录采样到的bbox,avoid_coll_boxes是已经采样到的对象的bbox和gt的bbox的累加
3rd: 从sampled中取出sample的点云和bbox,并返回相应的res
Sample_class_v2
Pipeline:
1、计算gt的bv box
2、计算采样得到的sp_boxes
3、获得valid_mask Array[bool: len(gt_boxes) + len(sp_boxes)]
4、如果允许global_rot,prep.noise_per_object_v3_ # place samples to any place in a circle
5、计算采样的bv box
6、根据总体的bv box,进行碰撞测试
7、返回无碰撞的对象