12
12
13
13
namespace YOLOv8WithOpenCVForUnity
14
14
{
15
-
16
- public class YOLOv8ClassPredictor
15
+ /// <summary>
16
+ /// Referring to https://github.com/ultralytics/ultralytics/
17
+ /// </summary>
18
+ public class YOLOv8ClassPredictor : IDisposable
17
19
{
18
20
Size input_size ;
19
21
int backend ;
@@ -38,7 +40,7 @@ public YOLOv8ClassPredictor(string modelFilepath, string classesFilepath, Size i
38
40
39
41
if ( ! string . IsNullOrEmpty ( classesFilepath ) )
40
42
{
41
- classNames = readClassNames ( classesFilepath ) ;
43
+ classNames = ReadClassNames ( classesFilepath ) ;
42
44
}
43
45
44
46
input_size = new Size ( inputSize . width > 0 ? inputSize . width : 224 , inputSize . height > 0 ? inputSize . height : 224 ) ;
@@ -71,7 +73,7 @@ public YOLOv8ClassPredictor(string modelFilepath, string classesFilepath, Size i
71
73
palette . Add ( new Scalar ( 255 , 55 , 199 , 255 ) ) ;
72
74
}
73
75
74
- protected virtual Mat preprocess ( Mat image )
76
+ protected virtual Mat PreProcess ( Mat image )
75
77
{
76
78
// https://github.com/ultralytics/ultralytics/blob/d74a5a9499acf1afd13d970645e5b1cfcadf4a8f/ultralytics/data/augment.py#L1059
77
79
@@ -96,7 +98,7 @@ protected virtual Mat preprocess(Mat image)
96
98
return blob ; // [1, 3, h, w]
97
99
}
98
100
99
- public virtual Mat infer ( Mat image )
101
+ public virtual Mat Infer ( Mat image )
100
102
{
101
103
// check
102
104
if ( image . channels ( ) != 3 )
@@ -106,7 +108,7 @@ public virtual Mat infer(Mat image)
106
108
}
107
109
108
110
// Preprocess
109
- Mat input_blob = preprocess ( image ) ;
111
+ Mat input_blob = PreProcess ( image ) ;
110
112
111
113
// Forward
112
114
classification_net . setInput ( input_blob ) ;
@@ -115,7 +117,7 @@ public virtual Mat infer(Mat image)
115
117
classification_net . forward ( output_blob , classification_net . getUnconnectedOutLayersNames ( ) ) ;
116
118
117
119
// Postprocess
118
- Mat results = postprocess ( output_blob , image . size ( ) ) ;
120
+ Mat results = PostProcess ( output_blob , image . size ( ) ) ;
119
121
120
122
input_blob . Dispose ( ) ;
121
123
for ( int i = 0 ; i < output_blob . Count ; i ++ )
@@ -126,7 +128,7 @@ public virtual Mat infer(Mat image)
126
128
return results ; // [1, num_classes]
127
129
}
128
130
129
- protected virtual Mat postprocess ( List < Mat > output_blob , Size original_shape )
131
+ protected virtual Mat PostProcess ( List < Mat > output_blob , Size original_shape )
130
132
{
131
133
Mat output_blob_0 = output_blob [ 0 ] ;
132
134
@@ -135,21 +137,7 @@ protected virtual Mat postprocess(List<Mat> output_blob, Size original_shape)
135
137
return results ; // [1, num_classes]
136
138
}
137
139
138
- protected virtual Mat softmax ( Mat src )
139
- {
140
- Mat dst = src . clone ( ) ;
141
-
142
- Core . MinMaxLocResult result = Core . minMaxLoc ( src ) ;
143
- Scalar max = new Scalar ( result . maxVal ) ;
144
- Core . subtract ( src , max , dst ) ;
145
- Core . exp ( dst , dst ) ;
146
- Scalar sum = Core . sumElems ( dst ) ;
147
- Core . divide ( dst , sum , dst ) ;
148
-
149
- return dst ;
150
- }
151
-
152
- public virtual void visualize ( Mat image , Mat results , bool print_results = false , bool isRGB = false )
140
+ public virtual void Visualize ( Mat image , Mat results , bool print_results = false , bool isRGB = false )
153
141
{
154
142
if ( image . IsDisposed )
155
143
return ;
@@ -162,9 +150,9 @@ public virtual void visualize(Mat image, Mat results, bool print_results = false
162
150
if ( print_results )
163
151
sb = new StringBuilder ( 64 ) ;
164
152
165
- ClassificationData bmData = getBestMatchData ( results ) ;
153
+ ClassificationData bmData = GetBestMatchData ( results ) ;
166
154
int classId = ( int ) bmData . cls ;
167
- string label = getClassLabel ( bmData . cls ) + ", " + bmData . conf . ToString ( "F2" ) ;
155
+ string label = GetClassLabel ( bmData . cls ) + ", " + bmData . conf . ToString ( "F2" ) ;
168
156
169
157
Scalar c = palette [ classId % palette . Count ] ;
170
158
Scalar color = isRGB ? c : new Scalar ( c . val [ 2 ] , c . val [ 1 ] , c . val [ 0 ] , c . val [ 3 ] ) ;
@@ -183,14 +171,14 @@ public virtual void visualize(Mat image, Mat results, bool print_results = false
183
171
// Print results
184
172
if ( print_results )
185
173
{
186
- sb . AppendLine ( "Best match: " + getClassLabel ( bmData . cls ) + ", " + bmData . ToString ( ) ) ;
174
+ sb . AppendLine ( "Best match: " + GetClassLabel ( bmData . cls ) + ", " + bmData . ToString ( ) ) ;
187
175
}
188
176
189
177
if ( print_results )
190
178
Debug . Log ( sb . ToString ( ) ) ;
191
179
}
192
180
193
- public virtual void dispose ( )
181
+ public virtual void Dispose ( )
194
182
{
195
183
if ( classification_net != null )
196
184
classification_net . Dispose ( ) ;
@@ -206,7 +194,7 @@ public virtual void dispose()
206
194
getDataMat = null ;
207
195
}
208
196
209
- protected virtual List < string > readClassNames ( string filename )
197
+ protected virtual List < string > ReadClassNames ( string filename )
210
198
{
211
199
List < string > classNames = new List < string > ( ) ;
212
200
@@ -235,14 +223,18 @@ protected virtual List<string> readClassNames(string filename)
235
223
return classNames ;
236
224
}
237
225
238
- [ StructLayout ( LayoutKind . Sequential ) ]
226
+ [ Serializable ]
227
+ [ StructLayout ( LayoutKind . Sequential , Pack = 1 ) ]
239
228
public readonly struct ClassificationData
240
229
{
241
230
public readonly float cls ;
242
231
public readonly float conf ;
243
232
244
- // sizeof(ClassificationData)
245
- public const int Size = 2 * sizeof ( float ) ;
233
+ // Count of elements
234
+ public const int ELEMENT_COUNT = 2 ;
235
+
236
+ // sizeof(ClassificationData)
237
+ public const int DATA_SIZE = ELEMENT_COUNT * sizeof ( float ) ;
246
238
247
239
public ClassificationData ( int cls , float conf )
248
240
{
@@ -252,11 +244,13 @@ public ClassificationData(int cls, float conf)
252
244
253
245
public override string ToString ( )
254
246
{
255
- return "cls:" + cls . ToString ( ) + " conf:" + conf . ToString ( ) ;
247
+ StringBuilder sb = new StringBuilder ( 64 ) ;
248
+ sb . AppendFormat ( "conf:{0} cls:{1}" , conf , cls ) ;
249
+ return sb . ToString ( ) ;
256
250
}
257
251
} ;
258
252
259
- public virtual ClassificationData [ ] getData ( Mat results )
253
+ public virtual ClassificationData [ ] GetData ( Mat results )
260
254
{
261
255
if ( results . empty ( ) )
262
256
return new ClassificationData [ 0 ] ;
@@ -279,20 +273,20 @@ public virtual ClassificationData[] getData(Mat results)
279
273
return dst ;
280
274
}
281
275
282
- public virtual ClassificationData [ ] getSortedData ( Mat results , int topK = 5 )
276
+ public virtual ClassificationData [ ] GetSortedData ( Mat results , int topK = 5 )
283
277
{
284
278
if ( results . empty ( ) )
285
279
return new ClassificationData [ 0 ] ;
286
280
287
281
int num = results . cols ( ) ;
288
282
289
283
if ( topK < 1 || topK > num ) topK = num ;
290
- var sortedData = getData ( results ) . OrderByDescending ( x => x . conf ) . Take ( topK ) . ToArray ( ) ;
284
+ var sortedData = GetData ( results ) . OrderByDescending ( x => x . conf ) . Take ( topK ) . ToArray ( ) ;
291
285
292
286
return sortedData ;
293
287
}
294
288
295
- public virtual ClassificationData getBestMatchData ( Mat results )
289
+ public virtual ClassificationData GetBestMatchData ( Mat results )
296
290
{
297
291
if ( results . empty ( ) )
298
292
return new ClassificationData ( ) ;
@@ -302,7 +296,7 @@ public virtual ClassificationData getBestMatchData(Mat results)
302
296
return new ClassificationData ( ( int ) minmax . maxLoc . x , ( float ) minmax . maxVal ) ;
303
297
}
304
298
305
- public virtual string getClassLabel ( float id )
299
+ public virtual string GetClassLabel ( float id )
306
300
{
307
301
int classId = ( int ) id ;
308
302
string className = string . Empty ;
0 commit comments