% AP, mAP and other metrics

% Evaluation setup: list the per-frame result files and initialise all
% accumulators used by the rest of the script.
clear;
box_path = "/home/koo/HPC/mmbox/code/fourth_version/foggy_box/";
list_path = dir(box_path);
% Keep regular files only. Dropping the first two entries (assumed to be
% '.' and '..') is not reliable: dir sorts by name, so files whose names
% sort before '.' would be skipped instead, and sub-directories would be
% passed to load().
list_path = list_path(~[list_path.isdir]);
iou_threshold = 0.5;
% per-detection rows [confidence, tp, fp] for class 0 / class 1.
% Start as 0x3 (not []) so column indexing such as result_0(:,2) stays
% valid even when a class has no detections at all.
result_0 = zeros(0, 3);
result_1 = zeros(0, 3);
% number of ground-truth objects per class over all frames
count_gt0 = 0;
count_gt1 = 0;
% total number of frames (one file per frame)
total_num = size(list_path, 1);

% accumulated depth error (signed per-match values go in the collections)
depth_loss0 = 0;
depth_loss1 = 0;
depth_loss0_collection = [];
depth_loss1_collection = [];
% centre point distance
dis_num0 = 0;
dis_num1 = 0;
dis_num0_collection = [];
dis_num1_collection = [];
% width ratio (predicted / ground truth)
w_ratio0 = 0;
w_ratio1 = 0;
% height ratio (predicted / ground truth)
h_ratio0 = 0;
h_ratio1 = 0;
h_ratio0_collection = [];
h_ratio1_collection = [];
w_ratio0_collection = [];
w_ratio1_collection = [];
% Match the predictions of every frame to that frame's ground truth and
% collect per-detection results plus error statistics of matched pairs.
% Layout inferred from the indexing below (NOTE(review): confirm against
% whatever writes these files):
%   data.prediction : one row per predicted box; columns 1-4 = y1 x1 y2 x2
%                     (presumably pixels), 5 = confidence, 6 = depth / 20,
%                     7 = class confidence (unused), 8 = predicted class.
%   data.detection  : reshaped to 20 rows x 6 columns; columns 1-4 =
%                     normalised centre y, centre x, height, width (scaled
%                     by 1080 x 1920 below), 5 = depth / 20, 6 = class.
%                     A row with column 5 == 0 ends the ground-truth list.
for path = 1:size(list_path,1)
    name = fullfile(box_path, list_path(path).name);
    data = load(name);
    prediction = data.prediction;
    gt = reshape(data.detection, [20,6]);
    % flag(j) == 1 once ground-truth row j is matched, so each ground-truth
    % box can be claimed by at most one prediction
    flag = zeros(size(gt,1), 1);

    % count ground-truth objects per class up to the first padding row
    for i = 1:size(gt,1)
        if gt(i,5) == 0
            break;
        end
        if gt(i,6) == 0
            count_gt0 = count_gt0 + 1;
        else
            count_gt1 = count_gt1 + 1;
        end
    end

    for i = 1:size(prediction,1)
        pre_one = prediction(i, :);
        y1 = pre_one(1);
        x1 = pre_one(2);
        y2 = pre_one(3);
        x2 = pre_one(4);
        conf = pre_one(5);
        depth = pre_one(6) * 20;
        c_pred = pre_one(8);
        pred_area = (y2 - y1) * (x2 - x1);

        % pick the unmatched same-class ground-truth box with the highest
        % IoU, provided that IoU reaches iou_threshold
        tp = 0;
        fp = 0;
        index_t = 0;
        max_iou = 0;
        for j = 1:size(gt, 1)
            if gt(j, 5) == 0
                break;      % padding row: no more ground truth in this frame
            end
            if gt(j, 6) ~= c_pred || flag(j) ~= 0
                continue;   % different class, or already matched
            end
            y1_gt = (gt(j,1) - gt(j,3)/2) * 1080;
            x1_gt = (gt(j,2) - gt(j,4)/2) * 1920;
            y2_gt = (gt(j,1) + gt(j,3)/2) * 1080;
            x2_gt = (gt(j,2) + gt(j,4)/2) * 1920;

            % intersection rectangle
            inter_h = min(y2, y2_gt) - max(y1, y1_gt);
            inter_w = min(x2, x2_gt) - max(x1, x1_gt);
            if inter_h < 0 || inter_w < 0
                continue;   % the boxes do not overlap
            end
            overlapped_area = inter_h * inter_w;
            % IoU divides by the area of the UNION, |A| + |B| - |A n B|.
            % Dividing by the area of the smallest box enclosing both (as
            % before) is never smaller than the union, so it understated
            % IoU and rejected valid matches at iou_threshold.
            gt_area = (y2_gt - y1_gt) * (x2_gt - x1_gt);
            union_area = pred_area + gt_area - overlapped_area;
            iou = overlapped_area / union_area;
            if iou >= iou_threshold && iou > max_iou
                % best match so far
                tp = 1;
                index_t = j;
                max_iou = iou;
            end
        end

        % an unmatched prediction is a false positive
        if tp == 0
            fp = 1;
        else
            flag(index_t) = 1;
        end

        % error statistics of a matched pair (same formulas for both classes)
        if tp == 1
            y1_gt = (gt(index_t,1) - gt(index_t,3)/2) * 1080;
            x1_gt = (gt(index_t,2) - gt(index_t,4)/2) * 1920;
            y2_gt = (gt(index_t,1) + gt(index_t,3)/2) * 1080;
            x2_gt = (gt(index_t,2) + gt(index_t,4)/2) * 1920;
            % signed depth error, ground truth minus prediction
            depth_err = gt(index_t, 5) * 20 - depth;
            % distance between box centres
            cp_x = (x2 + x1) / 2;
            cp_y = (y2 + y1) / 2;
            gt_cp_x = (x2_gt + x1_gt) / 2;
            gt_cp_y = (y2_gt + y1_gt) / 2;
            center_dis = sqrt((gt_cp_y - cp_y)^2 + (gt_cp_x - cp_x)^2);
            % box sizes for the width / height ratios
            w = x2 - x1;
            h = y2 - y1;
            w_gt = x2_gt - x1_gt;
            h_gt = y2_gt - y1_gt;
        end

        % record to vehicle (class 0) or pedestrian (any other class)
        if c_pred == 0
            result_0 = [result_0; conf, tp, fp];
            if tp == 1
                depth_loss0_collection = [depth_loss0_collection, depth_err];
                dis_num0_collection = [dis_num0_collection, center_dis];
                w_ratio0_collection = [w_ratio0_collection, w/w_gt];
                h_ratio0_collection = [h_ratio0_collection, h/h_gt];
            end
        else
            result_1 = [result_1; conf, tp, fp];
            if tp == 1
                depth_loss1_collection = [depth_loss1_collection, depth_err];
                dis_num1_collection = [dis_num1_collection, center_dis];
                w_ratio1_collection = [w_ratio1_collection, w/w_gt];
                h_ratio1_collection = [h_ratio1_collection, h/h_gt];
            end
        end
    end
end
% ---- missed ground-truth objects ----
% Column 2 of the per-detection tables is the 0/1 true-positive flag, so
% its non-zero count is the number of matched ground-truth boxes.
match_num0 = nnz(result_0(:, 2));
match_num1 = nnz(result_1(:, 2));
miss_num_0 = count_gt0 - match_num0;
miss_num_1 = count_gt1 - match_num1;

% misses averaged over the number of frames
miss_rate0 = miss_num_0 / total_num;
miss_rate1 = miss_num_1 / total_num;
disp("missed vehicle and pedestrian per frame: ")
disp(miss_rate0)
disp(miss_rate1)

% misses as a fraction of all ground-truth objects of the class
miss_rate_obj0 = miss_num_0 / count_gt0;
miss_rate_obj1 = miss_num_1 / count_gt1;
disp("missed vehicle and pedestrian per object: ")
disp(miss_rate_obj0)
disp(miss_rate_obj1)

% ---- depth error of matched detections ----
% The collections hold signed (ground truth - prediction) depth errors.
% Report the mean absolute error per class and keep the empirical CDF of
% |error|, dropping the first and last points of cdfplot's line data.

% vehicle (class 0)
depth_loss0 = sum(abs(depth_loss0_collection));
aver_depth_loss0 = depth_loss0 / match_num0;

% pedestrian (class 1)
depth_loss1 = sum(abs(depth_loss1_collection));
aver_depth_loss1 = depth_loss1 / match_num1;

depth_loss0_collection = sort(abs(depth_loss0_collection));
cdf = cdfplot(depth_loss0_collection);
cdf_depth_loss0 = [cdf.XData(2:end-1)', cdf.YData(2:end-1)'];

depth_loss1_collection = sort(abs(depth_loss1_collection));
cdf = cdfplot(depth_loss1_collection);
cdf_depth_loss1 = [cdf.XData(2:end-1)', cdf.YData(2:end-1)'];

disp("average depth loss for vehicle and pedestrian")
disp(aver_depth_loss0)
disp(aver_depth_loss1)

% ---- centre-point distance of matched detections (pixels) ----
% Mean distance per class, then the empirical CDF of the distances with
% the first and last points of cdfplot's line data dropped.
dis_num0 = sum(dis_num0_collection);
dis_num1 = sum(dis_num1_collection);
aver_dis0 = dis_num0 / match_num0;
aver_dis1 = dis_num1 / match_num1;
disp("average center point distance in pixels: ")
disp(aver_dis0)
disp(aver_dis1)

% vehicle (class 0)
dis_num0_collection = sort(dis_num0_collection);
cdf = cdfplot(dis_num0_collection);
cdf_dis_num0 = [cdf.XData(2:end-1)', cdf.YData(2:end-1)'];

% pedestrian (class 1)
dis_num1_collection = sort(dis_num1_collection);
cdf = cdfplot(dis_num1_collection);
cdf_dis_num1 = [cdf.XData(2:end-1)', cdf.YData(2:end-1)'];

% ---- width / height ratio error of matched detections ----
% Each collection is replaced by its sorted |predicted/ground-truth - 1|
% values; the per-class mean of those is reported, followed by the
% empirical CDFs (first and last points of cdfplot's line data dropped).
w_ratio0_collection = sort(abs(w_ratio0_collection - 1));
w_ratio0 = sum(w_ratio0_collection);
w_ratio1_collection = sort(abs(w_ratio1_collection - 1));
w_ratio1 = sum(w_ratio1_collection);
h_ratio0_collection = sort(abs(h_ratio0_collection - 1));
h_ratio0 = sum(h_ratio0_collection);
h_ratio1_collection = sort(abs(h_ratio1_collection - 1));
h_ratio1 = sum(h_ratio1_collection);

aver_w_ratio0 = w_ratio0 / match_num0;
aver_w_ratio1 = w_ratio1 / match_num1;
aver_h_ratio0 = h_ratio0 / match_num0;
aver_h_ratio1 = h_ratio1 / match_num1;

disp("average width ratio for vehicle and pedestrian:")
disp(aver_w_ratio0)
disp(aver_w_ratio1)
disp("average height ratio for vehicle and pedestrian: ")
disp(aver_h_ratio0)
disp(aver_h_ratio1)

cdf = cdfplot(w_ratio0_collection);
cdf_w_ratio0 = [cdf.XData(2:end-1)', cdf.YData(2:end-1)'];
cdf = cdfplot(w_ratio1_collection);
cdf_w_ratio1 = [cdf.XData(2:end-1)', cdf.YData(2:end-1)'];
cdf = cdfplot(h_ratio0_collection);
cdf_h_ratio0 = [cdf.XData(2:end-1)', cdf.YData(2:end-1)'];
cdf = cdfplot(h_ratio1_collection);
cdf_h_ratio1 = [cdf.XData(2:end-1)', cdf.YData(2:end-1)'];

% ---- precision / recall along the confidence ranking ----
% Rank each class's detections by confidence, highest first, then build
% one row per ranked detection:
%   [cumulative TP, cumulative FP, precision, recall]
% where recall is measured against the class's ground-truth count.
[~, sorted_index0] = sort(result_0(:,1), 1, 'descend');
[~, sorted_index1] = sort(result_1(:,1), 1, 'descend');
sorted_result0 = result_0(sorted_index0, :);
sorted_result1 = result_1(sorted_index1, :);

% vehicle (class 0)
is_tp0 = sorted_result0(:, 2) == 1;
cum_tp0 = cumsum(is_tp0);
cum_fp0 = cumsum(~is_tp0);
PR0 = [cum_tp0, cum_fp0, cum_tp0 ./ (cum_tp0 + cum_fp0), cum_tp0 ./ count_gt0];

% pedestrian (class 1)
is_tp1 = sorted_result1(:, 2) == 1;
cum_tp1 = cumsum(is_tp1);
cum_fp1 = cumsum(~is_tp1);
PR1 = [cum_tp1, cum_fp1, cum_tp1 ./ (cum_tp1 + cum_fp1), cum_tp1 ./ count_gt1];
% Raw precision-recall curves: left = vehicle (class 0), right =
% pedestrian (class 1). Titles are set on each subplot AFTER it is
% created: a title placed on the figure's default axes before calling
% subplot is lost, because subplot deletes axes that overlap the new one.
figure(1);
subplot(1,2,1);
plot(PR0(:,4), PR0(:,3));
title("Precision Recall curve (vehicle)");
xlabel("recall");
ylabel("precision");
subplot(1,2,2);
plot(PR1(:,4), PR1(:,3));
title("Precision Recall curve (pedestrian)");
xlabel("recall");
ylabel("precision");

% VOC-style AP per class. Steps:
%   1. Pad the PR curve with (recall 0, precision 0) and (recall 1,
%      precision 0).
%   2. Make precision non-increasing in recall (reverse running maximum).
%   3. Sum precision x recall step at every point where recall changes.
% Previously both classes wrote into `ap`, so the class-0 value was lost.
% The per-class results are now kept in ap0 / ap1. `ap`, `mrec` and `mpre`
% still end up holding the class-1 values, as before.

% vehicle (class 0)
mrec = [0; PR0(:,4); 1];
mpre = [0; PR0(:,3); 0];
mpre = cummax(mpre, 1, 'reverse');
i = find(mrec(2:end) ~= mrec(1:end-1)) + 1;
ap0 = sum((mrec(i) - mrec(i-1)) .* mpre(i));
mrec0 = mrec;
mpre0 = mpre;

% pedestrian (class 1)
mrec = [0; PR1(:,4); 1];
mpre = [0; PR1(:,3); 0];
mpre = cummax(mpre, 1, 'reverse');
i = find(mrec(2:end) ~= mrec(1:end-1)) + 1;
ap1 = sum((mrec(i) - mrec(i-1)) .* mpre(i));
ap = ap1;

disp("VOC AP for vehicle and pedestrian: ")
disp(ap0)
disp(ap1)

figure(100);
subplot(1,2,1);
plot(mrec0, mpre0);
title("VOC interpolated PR curve (vehicle)");
subplot(1,2,2);
plot(mrec, mpre);
title("VOC interpolated PR curve (pedestrian)");

% ---- all-point interpolation ----
% Work on copies of the PR tables and replace every precision (column 3)
% by the largest precision at the same or any later rank, i.e. a running
% maximum taken from the end of the ranking. Precisions are tp/(tp+fp)
% and therefore non-negative.
PR0_all = PR0;
PR1_all = PR1;
PR0_all(:, 3) = cummax(PR0_all(:, 3), 1, 'reverse');
PR1_all(:, 3) = cummax(PR1_all(:, 3), 1, 'reverse');

% All-point interpolated PR curves: left = vehicle (class 0), right =
% pedestrian (class 1). Titles are set per subplot after it is created;
% a title set before subplot is lost, because subplot deletes the
% overlapping default axes.
figure(2);
subplot(1,2,1);
plot(PR0_all(:,4), PR0_all(:,3));
title("All interpolated PR curve (vehicle)");
subplot(1,2,2);
plot(PR1_all(:,4), PR1_all(:,3));
title("All interpolated PR curve (pedestrian)");

% AP as the area under the all-point interpolated PR curve:
%   sum over ranks k of (recall(k) - recall(k-1)) * precision(k),
%   with recall(0) = 0.
% The previous loop started at the second rank and so skipped the
% segment [0, recall(1)]. For example, a perfect two-object ranking
% scored 0.5 instead of 1. Empty tables give 0.
AP0_all = sum(diff([0; PR0_all(:,4)]) .* PR0_all(:,3));
AP1_all = sum(diff([0; PR1_all(:,4)]) .* PR1_all(:,3));

disp("AP0_all = ")
disp(AP0_all)
disp("AP1_all = ")
disp(AP1_all)


% 11-point interpolated AP: at each recall level r in {0, 0.1, ..., 1},
% take the highest precision among ranks whose recall is >= r (0 if none),
% then average the 11 values.
% The comparison must be ">=". With ">", the r = 1.0 level was always 0,
% even at full recall, and every level needed recall strictly above r.
% NOTE: result_0 / result_1 are reused here for the 11 interpolated
% precisions, as in the original script. The per-detection tables are no
% longer needed at this point.
PR0_11 = PR0;
PR1_11 = PR1;

stride = (0:10) / 10;   % exact k/10, so recalls like 3/10 compare equal
result_0 = zeros(1, 11);
result_1 = zeros(1, 11);
for k = 1:numel(stride)
    reached0 = PR0_11(:,4) >= stride(k);
    if any(reached0)
        result_0(k) = max(PR0_11(reached0, 3));
    end
    reached1 = PR1_11(:,4) >= stride(k);
    if any(reached1)
        result_1(k) = max(PR1_11(reached1, 3));
    end
end

% left = vehicle (class 0), right = pedestrian (class 1). Titles are set
% per subplot because subplot deletes the figure's default axes.
figure(3);
subplot(1,2,1);
plot(stride, result_0);
title("11 interpolated PR curve (vehicle)");
subplot(1,2,2);
plot(stride, result_1);
title("11 interpolated PR curve (pedestrian)");

%AP
AP0_11 = sum(result_0)/11;
AP1_11 = sum(result_1)/11;
disp("AP0_11 = ")
disp(AP0_11)
disp("AP1_11 = ")
disp(AP1_11)



    

% Leave a Comment
%
% Your email address will not be published. Required fields are marked *
%
% Scroll to Top