CLRNet icon indicating copy to clipboard operation
CLRNet copied to clipboard

nms torch implementation problem

Open khan-yin opened this issue 2 years ago • 1 comments

hello author, I want to implement NMS code in Pytorch, I checked the cuda code and wrote devIoU and NMS process,but the keep index is not right in some imgs,maybe my understanding has some mistake, can you tell me the details and mistake in my NMS?Thanks a Lot! 🤣 here is a true prediction example:

import torch
import numpy as np
from clrnet.ops import nms

def devIoU_torch(a,b):
    """
    shape: conf,y,x,lenght,72offsets
    Compute distance between two lanes.
    """
    DATASET_OFFSET = 0
    n_strips = 71
    n_offsets = 72
    start_a = int(a[2] * n_strips - DATASET_OFFSET + 0.5)
    start_b = int(b[2] * n_strips - DATASET_OFFSET + 0.5)
    start = max(start_a, start_b)
    end_a = int(start_a + a[4] - 1 + 0.5 - int((a[4] - 1) < 0)) # - (x<0) trick to adjust for negative numbers (in case length is 0)
    end_b = int(start_b + b[4] - 1 + 0.5 - int((b[4] - 1) < 0))
    end = min(min(end_a, end_b), n_offsets - 1)

    if end < start:
        return torch.Tensor(1e9)
        
    dist = torch.abs(a[5+start:5+end+1] - b[5+start:5+end+1]).sum()
    return dist / (end - start + 1)

def lane_nms_torch(predictions, scores, nms_overlap_thresh, top_k):
    """
    NMS for lane detection.
    predictions: paddle.Tensor [num_lanes,conf,y,x,lenght,72offsets] 
    scores: paddle.Tensor [num_lanes]
    nms_overlap_thresh: float
    top_k: int
    """
    # sort by scores to get idx
    idx = scores.argsort(descending=True)
    print(idx)
    keep = []

    condidates = predictions.clone()
    condidates[idx]

    while len(condidates) > 0:
        keep.append(idx[0])
        if len(keep) >= top_k or len(condidates) == 1:
            break
        ious = []

        for i in range(1, len(condidates)):
            ious.append(devIoU_torch(condidates[0], condidates[i]))

        ious = torch.Tensor(ious)
        mask = ious < nms_overlap_thresh
        id = torch.where(mask==False)
        condidates = condidates[1:][id]
        idx = idx[1:][id]
    keep = torch.stack(keep)
    return keep


if __name__ == "__main__":
   # pred [12,77]
    pred = np.array([
        [-0.31868595 ,  0.31770626 ,  0.38531917 ,  0.00519343 ,  41.27337646,
         -231.61357117, -223.09695435, -214.59329224, -206.01312256, -197.44003296,
         -188.83171082, -180.23338318, -171.65122986, -163.03523254, -154.41949463,
         -145.80096436, -137.16973877, -128.51414490, -119.85771942, -111.19036865,
         -102.58344269, -93.94864655, -85.31758118, -76.67832184, -68.02027130,
         -59.38843155, -50.73698425, -42.11750412, -33.49232101, -24.86630058,
         -16.24045563, -7.58664083 ,  1.03324342 ,  9.67142010 ,  18.30059242,
          26.90684891,  35.56851578,  44.20260620,  52.84139633,  61.49933624,
          70.13185883,  78.74620819,  87.39436340,  96.04907227, 104.69485474,
         113.31373596, 121.96540833, 130.64146423, 139.33471680, 147.97837830,
         156.66036987, 165.34426880, 174.01264954, 182.65979004, 191.32659912,
         200.04383850, 208.78700256, 217.48742676, 226.16868591, 234.86297607,
         243.56433105, 252.28915405, 260.99911499, 269.74468994, 278.47253418,
         287.18377686, 295.90975952, 304.67031860, 313.44335938, 322.14849854,
         330.87860107, 339.55444336, 348.23510742, 356.94555664, 365.43688965,
         374.17858887, 381.64529419],
        [-0.75002342 ,  0.74795848 ,  0.38502458 ,  0.00486702 ,  41.25712585,
         -230.89216614, -222.40429688, -213.93872070, -205.38790894, -196.83256531,
         -188.24029541, -179.65783691, -171.09535217, -162.50091553, -153.90159607,
         -145.30148315, -136.68898010, -128.05361938, -119.41861725, -110.76705933,
         -102.17977142, -93.56396484, -84.95408630, -76.33931732, -67.70516968,
         -59.09665680, -50.46805573, -41.87142944, -33.27015305, -24.66850281,
         -16.06601334, -7.43314743 ,  1.16238201 ,  9.77579880 ,  18.37902641,
          26.96254158,  35.59840012,  44.20528793,  52.81388474,  61.44587708,
          70.05379486,  78.63992310,  87.26088715,  95.88831329, 104.50799561,
         113.10137177, 121.72947693, 130.37780762, 139.05093384, 147.67424011,
         156.34182739, 165.00785828, 173.65528870, 182.28471375, 190.93435669,
         199.63537598, 208.36180115, 217.04537964, 225.71020508, 234.38877869,
         243.07611084, 251.78483582, 260.48208618, 269.21682739, 277.93188477,
         286.62890625, 295.34487915, 304.09490967, 312.85647583, 321.54998779,
         330.26803589, 338.92849731, 347.59985352, 356.29257202, 364.75112915,
         373.47167969, 380.78469849],
        [-0.34532937 ,  0.34415907 ,  0.38420755 ,  0.00476420 ,  41.15870667,
         -229.36238098, -220.97476196, -212.63363647, -204.17485046, -195.65614319,
         -187.08787537, -178.53053284, -170.00151062, -161.43228149, -152.85386658,
         -144.27265930, -135.67857361, -127.05997467, -118.44338989, -109.79480743,
         -101.21628571, -92.61004639, -84.01746368, -75.42466736, -66.81315613,
         -58.22586441, -49.61507416, -41.03389740, -32.45261765, -23.87577820,
         -15.29488182, -6.67833805 ,  1.89210987 ,  10.48348618,  19.05630302,
          27.61728096,  36.22502899,  44.80586624,  53.37348557,  61.97953033,
          70.56378174,  79.11846161,  87.71255493,  96.30948639, 104.90310669,
         113.47233582, 122.07431030, 130.69110107, 139.35176086, 147.95822144,
         156.62069702, 165.27755737, 173.90750122, 182.52209473, 191.15945435,
         199.84896851, 208.56604004, 217.23840332, 225.89154053, 234.55934143,
         243.24171448, 251.93775940, 260.62664795, 269.35104370, 278.05526733,
         286.73504639, 295.45010376, 304.19104004, 312.93994141, 321.62597656,
         330.32675171, 338.96481323, 347.61602783, 356.28930664, 364.68209839,
         373.37594604, 380.76406860],
        [-0.41067860 ,  0.40940797 , -0.00003189 ,  0.26897055 ,  68.03149414,
         214.82826233, 217.18968201, 219.47599792, 221.79850769, 224.07466125,
         226.39747620, 228.72122192, 231.05665588, 233.34904480, 235.64422607,
         237.96136475, 240.25602722, 242.57543945, 244.92068481, 247.23052979,
         249.54858398, 251.87034607, 254.19195557, 256.53820801, 258.85311890,
         261.16168213, 263.47125244, 265.76770020, 268.07122803, 270.35028076,
         272.68768311, 275.01422119, 277.34231567, 279.65213013, 281.94415283,
         284.22647095, 286.52343750, 288.84524536, 291.20370483, 293.53698730,
         295.84234619, 298.15939331, 300.48004150, 302.77899170, 305.09832764,
         307.39053345, 309.67019653, 311.96490479, 314.32669067, 316.61538696,
         318.97769165, 321.32031250, 323.65866089, 325.96316528, 328.27459717,
         330.61215210, 332.95809937, 335.28503418, 337.59680176, 339.89752197,
         342.20916748, 344.57836914, 346.90518188, 349.27795410, 351.65844727,
         353.98501587, 356.27581787, 358.62582397, 361.03570557, 363.35931396,
         365.71292114, 368.02539062, 370.28140259, 372.64926147, 374.81219482,
         377.24099731, 380.33154297],
        [-0.72640061 ,  0.72439539 , -0.00167987 ,  0.26911533 ,  68.12149811,
         215.26119995, 217.61218262, 219.89451599, 222.21858215, 224.49700928,
         226.82194519, 229.14736938, 231.48216248, 233.77714539, 236.07197571,
         238.38998413, 240.68261719, 243.00123596, 245.34541321, 247.65528870,
         249.97476196, 252.29608154, 254.61735535, 256.96118164, 259.27450562,
         261.58166504, 263.88882446, 266.18057251, 268.48107910, 270.75704956,
         273.09109497, 275.41641235, 277.73950195, 280.04385376, 282.33142090,
         284.60891724, 286.90274048, 289.21881104, 291.56985474, 293.89788818,
         296.19769287, 298.50851440, 300.82315063, 303.11621094, 305.42971802,
         307.71917725, 309.99530029, 312.28683472, 314.64169312, 316.92556763,
         319.27969360, 321.61541748, 323.94741821, 326.24456787, 328.54986572,
         330.88174438, 333.22195435, 335.54244995, 337.84576416, 340.14077759,
         342.44793701, 344.81179810, 347.13189697, 349.49725342, 351.86975098,
         354.18948364, 356.47027588, 358.81427002, 361.21539307, 363.53390503,
         365.87539673, 368.17858887, 370.41424561, 372.77352905, 374.91915894,
         377.34335327, 380.39047241],
        [-0.81007600 ,  0.80792898 , -0.00094930 ,  0.26850182 ,  68.28794098,
         214.70718384, 217.07141113, 219.37139893, 221.71601868, 224.01040649,
         226.34999084, 228.69126892, 231.03834534, 233.34971619, 235.65794373,
         237.98731995, 240.29164124, 242.62205505, 244.97700500, 247.29676819,
         249.62249756, 251.95004272, 254.27650452, 256.62677002, 258.94677734,
         261.25915527, 263.57202148, 265.86761475, 268.17309570, 270.45516968,
         272.79190063, 275.12139893, 277.44461060, 279.75109863, 282.04144287,
         284.32232666, 286.62048340, 288.94125366, 291.29751587, 293.62939453,
         295.93460083, 298.25152588, 300.57208252, 302.87112427, 305.18969727,
         307.48165894, 309.76046753, 312.05880737, 314.41522217, 316.69967651,
         319.05505371, 321.39447021, 323.73059082, 326.02911377, 328.33883667,
         330.67376709, 333.01876831, 335.34179688, 337.64929199, 339.94601440,
         342.25762939, 344.62371826, 346.94418335, 349.30932617, 351.68157959,
         354.00231934, 356.28839111, 358.63580322, 361.03543091, 363.35610962,
         365.68569946, 367.99496460, 370.22161865, 372.58334351, 374.73937988,
         377.16043091, 380.32135010],
        [-0.29051638 ,  0.29024911 , -0.00159032 ,  0.87392622 ,  68.10390472,
         697.77691650, 693.40374756, 688.92913818, 684.48144531, 679.96618652,
         675.49066162, 671.01171875, 666.53656006, 662.03112793, 657.51007080,
         653.01464844, 648.49523926, 644.00335693, 639.53759766, 635.03039551,
         630.52117920, 626.01373291, 621.50445557, 617.02288818, 612.51068115,
         607.99151611, 603.46875000, 598.93011475, 594.40301514, 589.85253906,
         585.35565186, 580.85375977, 576.33996582, 571.81390381, 567.27282715,
         562.72192383, 558.18695068, 553.67047119, 549.20141602, 544.69360352,
         540.17108154, 535.65930176, 531.14739990, 526.61334229, 522.10424805,
         517.56579590, 513.01428223, 508.47760010, 504.00073242, 499.44625854,
         494.96115112, 490.47192383, 485.97872925, 481.43823242, 476.90655518,
         472.40371704, 467.91574097, 463.40844727, 458.87954712, 454.33596802,
         449.81726074, 445.34948730, 440.83154297, 436.35626221, 431.89852905,
         427.37966919, 422.83804321, 418.35394287, 413.91729736, 409.40515137,
         404.91085815, 400.40661621, 395.77481079, 391.30966187, 386.60757446,
         382.19287109, 378.51470947],
        [-0.00576371 ,  0.00659836 ,  0.00746224 ,  0.86676866 ,  67.92241669,
         694.64215088, 690.35174561, 685.93328857, 681.53295898, 677.07562256,
         672.65728760, 668.23492432, 663.81695557, 659.37127686, 654.90979004,
         650.48199463, 646.03204346, 641.61444092, 637.22253418, 632.80291748,
         628.37811279, 623.94818115, 619.52618408, 615.11804199, 610.69097900,
         606.24737549, 601.80572510, 597.35253906, 592.90411377, 588.43560791,
         584.01354980, 579.59887695, 575.16503906, 570.70684814, 566.24652100,
         561.77160645, 557.31817627, 552.87500000, 548.46685791, 544.04492188,
         539.58728027, 535.14367676, 530.70050049, 526.23522949, 521.80218506,
         517.32104492, 512.82489014, 508.34536743, 503.92623901, 499.43109131,
         495.00814819, 490.56948853, 486.12579346, 481.62396240, 477.13339233,
         472.67321777, 468.21511841, 463.74191284, 459.24081421, 454.73092651,
         450.23117065, 445.77410889, 441.26367188, 436.77929688, 432.30300903,
         427.78585815, 423.23910522, 418.74719238, 414.27862549, 409.76409912,
         405.24859619, 400.73580933, 396.12884521, 391.66561890, 387.06356812,
         382.65826416, 382.34899902],
        [-0.31965679 ,  0.31929627 ,  0.00110467 ,  0.87302345 ,  68.03258514,
         697.89147949, 693.51721191, 689.04290771, 684.59423828, 680.07824707,
         675.60180664, 671.12103271, 666.64471436, 662.13751221, 657.61572266,
         653.11883545, 648.59808350, 644.10675049, 639.64038086, 635.13464355,
         630.62750244, 626.12127686, 621.61425781, 617.13397217, 612.62438965,
         608.10754395, 603.58758545, 599.05169678, 594.52661133, 589.97894287,
         585.48339844, 580.98345947, 576.47259521, 571.94738770, 567.40795898,
         562.85687256, 558.32464600, 553.81072998, 549.34069824, 544.83691406,
         540.31347656, 535.80187988, 531.29156494, 526.75878906, 522.24902344,
         517.71166992, 513.15954590, 508.62432861, 504.14721680, 499.59600830,
         495.11145020, 490.62182617, 486.12734985, 481.58853149, 477.05862427,
         472.55737305, 468.07015991, 463.56161499, 459.03399658, 454.49346924,
         449.97518921, 445.50671387, 440.98974609, 436.51455688, 432.05389404,
         427.53680420, 422.99224854, 418.50610352, 414.06573486, 409.55181885,
         405.04916382, 400.53988647, 395.90856934, 391.44165039, 386.74603271,
         382.33010864, 378.99044800],
        [ 0.07455742 , -0.07334602 ,  0.39454764 ,  0.99547267 ,  40.40071106,
         1067.76831055, 1057.96215820, 1048.22729492, 1038.51684570, 1028.75830078,
         1019.04309082, 1009.28332520, 999.51672363, 989.78887939, 980.04296875,
         970.30590820, 960.57958984, 950.88720703, 941.19433594, 931.49438477,
         921.75988770, 912.03833008, 902.31927490, 892.61517334, 882.93634033,
         873.23608398, 863.54760742, 853.82525635, 844.10662842, 834.38940430,
         824.65740967, 814.95111084, 805.21063232, 795.49798584, 785.77880859,
         776.03155518, 766.33703613, 756.62744141, 746.92089844, 737.22479248,
         727.49670410, 717.75238037, 708.03106689, 698.31121826, 688.58679199,
         678.82556152, 669.07934570, 659.36187744, 649.65527344, 639.88195801,
         630.15380859, 620.40307617, 610.64776611, 600.84570312, 591.06762695,
         581.31811523, 571.59722900, 561.83294678, 552.04724121, 542.25549316,
         532.46020508, 522.68090820, 512.86572266, 503.07910156, 493.28512573,
         483.46408081, 473.64993286, 463.86849976, 454.08547974, 444.25100708,
         434.44918823, 424.60458374, 414.79748535, 405.00497437, 395.17172241,
         385.43466187, 376.48547363],
        [-0.03752957 ,  0.03828224 ,  0.41018096 ,  0.99468088 ,  39.48040009,
         1086.00512695, 1075.91198730, 1065.83984375, 1055.81005859, 1045.75695801,
         1035.73913574, 1025.70690918, 1015.66418457, 1005.66235352, 995.64117432,
         985.62591553, 975.62756348, 965.65972900, 955.69561768, 945.72723389,
         935.71807861, 925.72833252, 915.73748779, 905.76385498, 895.81085205,
         885.83331299, 875.87072754, 865.87841797, 855.88659668, 845.89831543,
         835.89935303, 825.92456055, 815.92163086, 805.94140625, 795.95269775,
         785.93420410, 775.97155762, 765.99053955, 756.01483154, 746.05224609,
         736.05627441, 726.04272461, 716.05487061, 706.07214355, 696.08099365,
         686.05523682, 676.04962158, 666.07598877, 656.10864258, 646.07855225,
         636.08483887, 626.07934570, 616.07196045, 606.01892090, 595.98834229,
         585.98876953, 576.01861572, 566.00604248, 555.97088623, 545.93383789,
         535.89184570, 525.86999512, 515.81756592, 505.79135132, 495.75280762,
         485.69030762, 475.63836670, 465.61627197, 455.59436035, 445.52136230,
         435.47381592, 425.39047241, 415.33776855, 405.29879761, 395.19256592,
         385.20031738, 375.88513184],
        [ 0.17124104 , -0.16978368 ,  0.43160471 ,  0.99431741 ,  37.68833160,
         1111.33093262, 1100.93151855, 1090.52978516, 1080.20104980, 1069.81665039,
         1059.48046875, 1049.08728027, 1038.69274902, 1028.34228516, 1017.98071289,
         1007.60839844, 997.24450684, 986.92395020, 976.59228516, 966.25561523,
         955.87915039, 945.51141357, 935.14770508, 924.80004883, 914.48101807,
         904.14605713, 893.81237793, 883.44708252, 873.09161377, 862.72821045,
         852.35546875, 842.01141357, 831.62939453, 821.28149414, 810.92584229,
         800.53900146, 790.20294189, 779.86090088, 769.52050781, 759.19104004,
         748.82421875, 738.44903564, 728.10205078, 717.75939941, 707.40148926,
         697.00305176, 686.63433838, 676.29516602, 665.97991943, 655.58953857,
         645.23370361, 634.87048340, 624.49670410, 614.08905029, 603.70233154,
         593.35516357, 583.03149414, 572.65771484, 562.26654053, 551.88269043,
         541.49017334, 531.11938477, 520.71899414, 510.35568237, 499.98645020,
         489.59848022, 479.22079468, 468.87136841, 458.53228760, 448.12460327,
         437.76074219, 427.35095215, 416.98913574, 406.62875366, 396.18240356,
         385.84442139, 376.15112305]])
    scores = np.array([0.65393746, 0.81727326, 0.66585314, 0.69425476, 0.81012094, 0.83451980,
        0.64124352, 0.50309050, 0.65451682, 0.46309140, 0.51894391, 0.41556057])
    print(pred.shape)
    print(scores.shape)

    torch_pred = torch.from_numpy(pred).cuda()
    torch_scores = torch.from_numpy(scores).cuda()
    print(torch_scores.argsort(descending=True))
    torch_keep, num_to_keep, _  = nms(torch_pred, torch_scores, 50, 4)
    keep = lane_nms_torch(torch_pred, torch_scores, 50, 4)
    print("offical keep:",torch_keep[:num_to_keep])
    print("my keep:",keep)
    

result:

(12, 77)
(12,)
tensor([ 5,  1,  4,  3,  2,  8,  0,  6, 10,  7,  9, 11], device='cuda:0')
tensor([ 5,  1,  4,  3,  2,  8,  0,  6, 10,  7,  9, 11], device='cuda:0')
offical keep: tensor([ 5,  1,  8, 10], device='cuda:0')
my keep: tensor([5, 3, 0, 7], device='cuda:0')

khan-yin avatar Apr 02 '23 08:04 khan-yin

it seems that return torch.Tensor(1e9) should be return torch.tensor(1e9)

mayunfei0627 avatar Apr 18 '23 08:04 mayunfei0627