CLRNet
CLRNet copied to clipboard
nms torch implementation problem
hello author, I want to implement NMS code in Pytorch, I checked the cuda code and wrote devIoU and NMS process,but the keep index is not right in some imgs,maybe my understanding has some mistake, can you tell me the details and mistake in my NMS?Thanks a Lot! 🤣 here is a true prediction example:
import torch
import numpy as np
from clrnet.ops import nms
def devIoU_torch(a,b):
"""
shape: conf,y,x,lenght,72offsets
Compute distance between two lanes.
"""
DATASET_OFFSET = 0
n_strips = 71
n_offsets = 72
start_a = int(a[2] * n_strips - DATASET_OFFSET + 0.5)
start_b = int(b[2] * n_strips - DATASET_OFFSET + 0.5)
start = max(start_a, start_b)
end_a = int(start_a + a[4] - 1 + 0.5 - int((a[4] - 1) < 0)) # - (x<0) trick to adjust for negative numbers (in case length is 0)
end_b = int(start_b + b[4] - 1 + 0.5 - int((b[4] - 1) < 0))
end = min(min(end_a, end_b), n_offsets - 1)
if end < start:
return torch.Tensor(1e9)
dist = torch.abs(a[5+start:5+end+1] - b[5+start:5+end+1]).sum()
return dist / (end - start + 1)
def lane_nms_torch(predictions, scores, nms_overlap_thresh, top_k):
"""
NMS for lane detection.
predictions: paddle.Tensor [num_lanes,conf,y,x,lenght,72offsets]
scores: paddle.Tensor [num_lanes]
nms_overlap_thresh: float
top_k: int
"""
# sort by scores to get idx
idx = scores.argsort(descending=True)
print(idx)
keep = []
condidates = predictions.clone()
condidates[idx]
while len(condidates) > 0:
keep.append(idx[0])
if len(keep) >= top_k or len(condidates) == 1:
break
ious = []
for i in range(1, len(condidates)):
ious.append(devIoU_torch(condidates[0], condidates[i]))
ious = torch.Tensor(ious)
mask = ious < nms_overlap_thresh
id = torch.where(mask==False)
condidates = condidates[1:][id]
idx = idx[1:][id]
keep = torch.stack(keep)
return keep
if __name__ == "__main__":
# pred [12,77]
pred = np.array([
[-0.31868595 , 0.31770626 , 0.38531917 , 0.00519343 , 41.27337646,
-231.61357117, -223.09695435, -214.59329224, -206.01312256, -197.44003296,
-188.83171082, -180.23338318, -171.65122986, -163.03523254, -154.41949463,
-145.80096436, -137.16973877, -128.51414490, -119.85771942, -111.19036865,
-102.58344269, -93.94864655, -85.31758118, -76.67832184, -68.02027130,
-59.38843155, -50.73698425, -42.11750412, -33.49232101, -24.86630058,
-16.24045563, -7.58664083 , 1.03324342 , 9.67142010 , 18.30059242,
26.90684891, 35.56851578, 44.20260620, 52.84139633, 61.49933624,
70.13185883, 78.74620819, 87.39436340, 96.04907227, 104.69485474,
113.31373596, 121.96540833, 130.64146423, 139.33471680, 147.97837830,
156.66036987, 165.34426880, 174.01264954, 182.65979004, 191.32659912,
200.04383850, 208.78700256, 217.48742676, 226.16868591, 234.86297607,
243.56433105, 252.28915405, 260.99911499, 269.74468994, 278.47253418,
287.18377686, 295.90975952, 304.67031860, 313.44335938, 322.14849854,
330.87860107, 339.55444336, 348.23510742, 356.94555664, 365.43688965,
374.17858887, 381.64529419],
[-0.75002342 , 0.74795848 , 0.38502458 , 0.00486702 , 41.25712585,
-230.89216614, -222.40429688, -213.93872070, -205.38790894, -196.83256531,
-188.24029541, -179.65783691, -171.09535217, -162.50091553, -153.90159607,
-145.30148315, -136.68898010, -128.05361938, -119.41861725, -110.76705933,
-102.17977142, -93.56396484, -84.95408630, -76.33931732, -67.70516968,
-59.09665680, -50.46805573, -41.87142944, -33.27015305, -24.66850281,
-16.06601334, -7.43314743 , 1.16238201 , 9.77579880 , 18.37902641,
26.96254158, 35.59840012, 44.20528793, 52.81388474, 61.44587708,
70.05379486, 78.63992310, 87.26088715, 95.88831329, 104.50799561,
113.10137177, 121.72947693, 130.37780762, 139.05093384, 147.67424011,
156.34182739, 165.00785828, 173.65528870, 182.28471375, 190.93435669,
199.63537598, 208.36180115, 217.04537964, 225.71020508, 234.38877869,
243.07611084, 251.78483582, 260.48208618, 269.21682739, 277.93188477,
286.62890625, 295.34487915, 304.09490967, 312.85647583, 321.54998779,
330.26803589, 338.92849731, 347.59985352, 356.29257202, 364.75112915,
373.47167969, 380.78469849],
[-0.34532937 , 0.34415907 , 0.38420755 , 0.00476420 , 41.15870667,
-229.36238098, -220.97476196, -212.63363647, -204.17485046, -195.65614319,
-187.08787537, -178.53053284, -170.00151062, -161.43228149, -152.85386658,
-144.27265930, -135.67857361, -127.05997467, -118.44338989, -109.79480743,
-101.21628571, -92.61004639, -84.01746368, -75.42466736, -66.81315613,
-58.22586441, -49.61507416, -41.03389740, -32.45261765, -23.87577820,
-15.29488182, -6.67833805 , 1.89210987 , 10.48348618, 19.05630302,
27.61728096, 36.22502899, 44.80586624, 53.37348557, 61.97953033,
70.56378174, 79.11846161, 87.71255493, 96.30948639, 104.90310669,
113.47233582, 122.07431030, 130.69110107, 139.35176086, 147.95822144,
156.62069702, 165.27755737, 173.90750122, 182.52209473, 191.15945435,
199.84896851, 208.56604004, 217.23840332, 225.89154053, 234.55934143,
243.24171448, 251.93775940, 260.62664795, 269.35104370, 278.05526733,
286.73504639, 295.45010376, 304.19104004, 312.93994141, 321.62597656,
330.32675171, 338.96481323, 347.61602783, 356.28930664, 364.68209839,
373.37594604, 380.76406860],
[-0.41067860 , 0.40940797 , -0.00003189 , 0.26897055 , 68.03149414,
214.82826233, 217.18968201, 219.47599792, 221.79850769, 224.07466125,
226.39747620, 228.72122192, 231.05665588, 233.34904480, 235.64422607,
237.96136475, 240.25602722, 242.57543945, 244.92068481, 247.23052979,
249.54858398, 251.87034607, 254.19195557, 256.53820801, 258.85311890,
261.16168213, 263.47125244, 265.76770020, 268.07122803, 270.35028076,
272.68768311, 275.01422119, 277.34231567, 279.65213013, 281.94415283,
284.22647095, 286.52343750, 288.84524536, 291.20370483, 293.53698730,
295.84234619, 298.15939331, 300.48004150, 302.77899170, 305.09832764,
307.39053345, 309.67019653, 311.96490479, 314.32669067, 316.61538696,
318.97769165, 321.32031250, 323.65866089, 325.96316528, 328.27459717,
330.61215210, 332.95809937, 335.28503418, 337.59680176, 339.89752197,
342.20916748, 344.57836914, 346.90518188, 349.27795410, 351.65844727,
353.98501587, 356.27581787, 358.62582397, 361.03570557, 363.35931396,
365.71292114, 368.02539062, 370.28140259, 372.64926147, 374.81219482,
377.24099731, 380.33154297],
[-0.72640061 , 0.72439539 , -0.00167987 , 0.26911533 , 68.12149811,
215.26119995, 217.61218262, 219.89451599, 222.21858215, 224.49700928,
226.82194519, 229.14736938, 231.48216248, 233.77714539, 236.07197571,
238.38998413, 240.68261719, 243.00123596, 245.34541321, 247.65528870,
249.97476196, 252.29608154, 254.61735535, 256.96118164, 259.27450562,
261.58166504, 263.88882446, 266.18057251, 268.48107910, 270.75704956,
273.09109497, 275.41641235, 277.73950195, 280.04385376, 282.33142090,
284.60891724, 286.90274048, 289.21881104, 291.56985474, 293.89788818,
296.19769287, 298.50851440, 300.82315063, 303.11621094, 305.42971802,
307.71917725, 309.99530029, 312.28683472, 314.64169312, 316.92556763,
319.27969360, 321.61541748, 323.94741821, 326.24456787, 328.54986572,
330.88174438, 333.22195435, 335.54244995, 337.84576416, 340.14077759,
342.44793701, 344.81179810, 347.13189697, 349.49725342, 351.86975098,
354.18948364, 356.47027588, 358.81427002, 361.21539307, 363.53390503,
365.87539673, 368.17858887, 370.41424561, 372.77352905, 374.91915894,
377.34335327, 380.39047241],
[-0.81007600 , 0.80792898 , -0.00094930 , 0.26850182 , 68.28794098,
214.70718384, 217.07141113, 219.37139893, 221.71601868, 224.01040649,
226.34999084, 228.69126892, 231.03834534, 233.34971619, 235.65794373,
237.98731995, 240.29164124, 242.62205505, 244.97700500, 247.29676819,
249.62249756, 251.95004272, 254.27650452, 256.62677002, 258.94677734,
261.25915527, 263.57202148, 265.86761475, 268.17309570, 270.45516968,
272.79190063, 275.12139893, 277.44461060, 279.75109863, 282.04144287,
284.32232666, 286.62048340, 288.94125366, 291.29751587, 293.62939453,
295.93460083, 298.25152588, 300.57208252, 302.87112427, 305.18969727,
307.48165894, 309.76046753, 312.05880737, 314.41522217, 316.69967651,
319.05505371, 321.39447021, 323.73059082, 326.02911377, 328.33883667,
330.67376709, 333.01876831, 335.34179688, 337.64929199, 339.94601440,
342.25762939, 344.62371826, 346.94418335, 349.30932617, 351.68157959,
354.00231934, 356.28839111, 358.63580322, 361.03543091, 363.35610962,
365.68569946, 367.99496460, 370.22161865, 372.58334351, 374.73937988,
377.16043091, 380.32135010],
[-0.29051638 , 0.29024911 , -0.00159032 , 0.87392622 , 68.10390472,
697.77691650, 693.40374756, 688.92913818, 684.48144531, 679.96618652,
675.49066162, 671.01171875, 666.53656006, 662.03112793, 657.51007080,
653.01464844, 648.49523926, 644.00335693, 639.53759766, 635.03039551,
630.52117920, 626.01373291, 621.50445557, 617.02288818, 612.51068115,
607.99151611, 603.46875000, 598.93011475, 594.40301514, 589.85253906,
585.35565186, 580.85375977, 576.33996582, 571.81390381, 567.27282715,
562.72192383, 558.18695068, 553.67047119, 549.20141602, 544.69360352,
540.17108154, 535.65930176, 531.14739990, 526.61334229, 522.10424805,
517.56579590, 513.01428223, 508.47760010, 504.00073242, 499.44625854,
494.96115112, 490.47192383, 485.97872925, 481.43823242, 476.90655518,
472.40371704, 467.91574097, 463.40844727, 458.87954712, 454.33596802,
449.81726074, 445.34948730, 440.83154297, 436.35626221, 431.89852905,
427.37966919, 422.83804321, 418.35394287, 413.91729736, 409.40515137,
404.91085815, 400.40661621, 395.77481079, 391.30966187, 386.60757446,
382.19287109, 378.51470947],
[-0.00576371 , 0.00659836 , 0.00746224 , 0.86676866 , 67.92241669,
694.64215088, 690.35174561, 685.93328857, 681.53295898, 677.07562256,
672.65728760, 668.23492432, 663.81695557, 659.37127686, 654.90979004,
650.48199463, 646.03204346, 641.61444092, 637.22253418, 632.80291748,
628.37811279, 623.94818115, 619.52618408, 615.11804199, 610.69097900,
606.24737549, 601.80572510, 597.35253906, 592.90411377, 588.43560791,
584.01354980, 579.59887695, 575.16503906, 570.70684814, 566.24652100,
561.77160645, 557.31817627, 552.87500000, 548.46685791, 544.04492188,
539.58728027, 535.14367676, 530.70050049, 526.23522949, 521.80218506,
517.32104492, 512.82489014, 508.34536743, 503.92623901, 499.43109131,
495.00814819, 490.56948853, 486.12579346, 481.62396240, 477.13339233,
472.67321777, 468.21511841, 463.74191284, 459.24081421, 454.73092651,
450.23117065, 445.77410889, 441.26367188, 436.77929688, 432.30300903,
427.78585815, 423.23910522, 418.74719238, 414.27862549, 409.76409912,
405.24859619, 400.73580933, 396.12884521, 391.66561890, 387.06356812,
382.65826416, 382.34899902],
[-0.31965679 , 0.31929627 , 0.00110467 , 0.87302345 , 68.03258514,
697.89147949, 693.51721191, 689.04290771, 684.59423828, 680.07824707,
675.60180664, 671.12103271, 666.64471436, 662.13751221, 657.61572266,
653.11883545, 648.59808350, 644.10675049, 639.64038086, 635.13464355,
630.62750244, 626.12127686, 621.61425781, 617.13397217, 612.62438965,
608.10754395, 603.58758545, 599.05169678, 594.52661133, 589.97894287,
585.48339844, 580.98345947, 576.47259521, 571.94738770, 567.40795898,
562.85687256, 558.32464600, 553.81072998, 549.34069824, 544.83691406,
540.31347656, 535.80187988, 531.29156494, 526.75878906, 522.24902344,
517.71166992, 513.15954590, 508.62432861, 504.14721680, 499.59600830,
495.11145020, 490.62182617, 486.12734985, 481.58853149, 477.05862427,
472.55737305, 468.07015991, 463.56161499, 459.03399658, 454.49346924,
449.97518921, 445.50671387, 440.98974609, 436.51455688, 432.05389404,
427.53680420, 422.99224854, 418.50610352, 414.06573486, 409.55181885,
405.04916382, 400.53988647, 395.90856934, 391.44165039, 386.74603271,
382.33010864, 378.99044800],
[ 0.07455742 , -0.07334602 , 0.39454764 , 0.99547267 , 40.40071106,
1067.76831055, 1057.96215820, 1048.22729492, 1038.51684570, 1028.75830078,
1019.04309082, 1009.28332520, 999.51672363, 989.78887939, 980.04296875,
970.30590820, 960.57958984, 950.88720703, 941.19433594, 931.49438477,
921.75988770, 912.03833008, 902.31927490, 892.61517334, 882.93634033,
873.23608398, 863.54760742, 853.82525635, 844.10662842, 834.38940430,
824.65740967, 814.95111084, 805.21063232, 795.49798584, 785.77880859,
776.03155518, 766.33703613, 756.62744141, 746.92089844, 737.22479248,
727.49670410, 717.75238037, 708.03106689, 698.31121826, 688.58679199,
678.82556152, 669.07934570, 659.36187744, 649.65527344, 639.88195801,
630.15380859, 620.40307617, 610.64776611, 600.84570312, 591.06762695,
581.31811523, 571.59722900, 561.83294678, 552.04724121, 542.25549316,
532.46020508, 522.68090820, 512.86572266, 503.07910156, 493.28512573,
483.46408081, 473.64993286, 463.86849976, 454.08547974, 444.25100708,
434.44918823, 424.60458374, 414.79748535, 405.00497437, 395.17172241,
385.43466187, 376.48547363],
[-0.03752957 , 0.03828224 , 0.41018096 , 0.99468088 , 39.48040009,
1086.00512695, 1075.91198730, 1065.83984375, 1055.81005859, 1045.75695801,
1035.73913574, 1025.70690918, 1015.66418457, 1005.66235352, 995.64117432,
985.62591553, 975.62756348, 965.65972900, 955.69561768, 945.72723389,
935.71807861, 925.72833252, 915.73748779, 905.76385498, 895.81085205,
885.83331299, 875.87072754, 865.87841797, 855.88659668, 845.89831543,
835.89935303, 825.92456055, 815.92163086, 805.94140625, 795.95269775,
785.93420410, 775.97155762, 765.99053955, 756.01483154, 746.05224609,
736.05627441, 726.04272461, 716.05487061, 706.07214355, 696.08099365,
686.05523682, 676.04962158, 666.07598877, 656.10864258, 646.07855225,
636.08483887, 626.07934570, 616.07196045, 606.01892090, 595.98834229,
585.98876953, 576.01861572, 566.00604248, 555.97088623, 545.93383789,
535.89184570, 525.86999512, 515.81756592, 505.79135132, 495.75280762,
485.69030762, 475.63836670, 465.61627197, 455.59436035, 445.52136230,
435.47381592, 425.39047241, 415.33776855, 405.29879761, 395.19256592,
385.20031738, 375.88513184],
[ 0.17124104 , -0.16978368 , 0.43160471 , 0.99431741 , 37.68833160,
1111.33093262, 1100.93151855, 1090.52978516, 1080.20104980, 1069.81665039,
1059.48046875, 1049.08728027, 1038.69274902, 1028.34228516, 1017.98071289,
1007.60839844, 997.24450684, 986.92395020, 976.59228516, 966.25561523,
955.87915039, 945.51141357, 935.14770508, 924.80004883, 914.48101807,
904.14605713, 893.81237793, 883.44708252, 873.09161377, 862.72821045,
852.35546875, 842.01141357, 831.62939453, 821.28149414, 810.92584229,
800.53900146, 790.20294189, 779.86090088, 769.52050781, 759.19104004,
748.82421875, 738.44903564, 728.10205078, 717.75939941, 707.40148926,
697.00305176, 686.63433838, 676.29516602, 665.97991943, 655.58953857,
645.23370361, 634.87048340, 624.49670410, 614.08905029, 603.70233154,
593.35516357, 583.03149414, 572.65771484, 562.26654053, 551.88269043,
541.49017334, 531.11938477, 520.71899414, 510.35568237, 499.98645020,
489.59848022, 479.22079468, 468.87136841, 458.53228760, 448.12460327,
437.76074219, 427.35095215, 416.98913574, 406.62875366, 396.18240356,
385.84442139, 376.15112305]])
scores = np.array([0.65393746, 0.81727326, 0.66585314, 0.69425476, 0.81012094, 0.83451980,
0.64124352, 0.50309050, 0.65451682, 0.46309140, 0.51894391, 0.41556057])
print(pred.shape)
print(scores.shape)
torch_pred = torch.from_numpy(pred).cuda()
torch_scores = torch.from_numpy(scores).cuda()
print(torch_scores.argsort(descending=True))
torch_keep, num_to_keep, _ = nms(torch_pred, torch_scores, 50, 4)
keep = lane_nms_torch(torch_pred, torch_scores, 50, 4)
print("offical keep:",torch_keep[:num_to_keep])
print("my keep:",keep)
result:
(12, 77)
(12,)
tensor([ 5, 1, 4, 3, 2, 8, 0, 6, 10, 7, 9, 11], device='cuda:0')
tensor([ 5, 1, 4, 3, 2, 8, 0, 6, 10, 7, 9, 11], device='cuda:0')
offical keep: tensor([ 5, 1, 8, 10], device='cuda:0')
my keep: tensor([5, 3, 0, 7], device='cuda:0')
it seems that return torch.Tensor(1e9) should be return torch.tensor(1e9)