-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
121 lines (107 loc) · 3.13 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package main
import (
"flag"
"fmt"
"log"
"strings"
"time"
"github.com/dev6699/yolotriton"
)
// Flags holds the values of every command-line flag accepted by this program
// (see parseFlags for the flag names and defaults).
type Flags struct {
	// Model selection: served model name (-m), version (-x, empty selects the
	// latest version) and model family (-t).
	ModelName, ModelVersion, ModelType string

	// Detection thresholds: minimum probability (-p) and maximum IoU (-o).
	MinProbability, MaxIOU float64

	// Inference server address (-u) and path of the input image (-i).
	URL, Image string

	// Benchmark mode (-b) and how many inference runs to time (-n).
	Benchmark      bool
	BenchmarkCount int
}
// parseFlags registers this program's flags on the default command-line flag
// set, parses os.Args via flag.Parse and returns the resulting values.
//
// Every flag has a default value, so none of them is strictly required; the
// help text therefore no longer labels -m as "(Required)".
func parseFlags() Flags {
	var flags Flags
	flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served")
	flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
	flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]")
	flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
	flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
	flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
	flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
	flag.BoolVar(&flags.Benchmark, "b", false, "Run benchmark.")
	flag.IntVar(&flags.BenchmarkCount, "n", 1, "Number of benchmark runs.")
	// flag.Parse uses flag.CommandLine, which exits the process on a parse
	// error, so there is no error to handle here.
	flag.Parse()
	return flags
}
// main parses the command-line flags, builds the requested YOLO model, connects
// to the Triton inference server given by -u and runs inference on the image
// given by -i.
//
// Without -b it runs one inference, prints every detection and saves an
// annotated copy of the image (see outputPath). With -b it runs -n inferences
// and prints only the per-run and average processing time.
func main() {
	flags := parseFlags()
	fmt.Println("FLAGS:", flags)

	// The average below divides by the number of runs; a count of zero would
	// panic with an integer divide by zero and a negative one would yield a
	// meaningless average, so reject both up front.
	if flags.Benchmark && flags.BenchmarkCount < 1 {
		log.Fatalf("Invalid benchmark count %d: must be at least 1", flags.BenchmarkCount)
	}

	cfg := yolotriton.YoloTritonConfig{
		ModelName:      flags.ModelName,
		ModelVersion:   flags.ModelVersion,
		MinProbability: float32(flags.MinProbability),
		MaxIOU:         flags.MaxIOU,
		Classes:        yolotriton.YoloClasses,
	}

	var model yolotriton.Model
	switch yolotriton.ModelType(flags.ModelType) {
	case yolotriton.ModelTypeYoloV8:
		cfg.NumClasses = 80
		cfg.NumObjects = 8400
		model = yolotriton.NewYoloV8(cfg)
	case yolotriton.ModelTypeYoloNAS:
		cfg.NumClasses = 80
		cfg.NumObjects = 8400
		model = yolotriton.NewYoloNAS(cfg)
	case yolotriton.ModelTypeYoloNASInt8:
		// NOTE(review): NumClasses/NumObjects are not set for this model type;
		// presumably NewYoloNASInt8 does not need them — confirm.
		model = yolotriton.NewYoloNASInt8(cfg)
	default:
		log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", flags.ModelType)
	}

	yt, err := yolotriton.New(flags.URL, model)
	if err != nil {
		log.Fatal(err)
	}

	img, err := yolotriton.LoadImage(flags.Image)
	if err != nil {
		log.Fatalf("Failed to load image %q: %v", flags.Image, err)
	}

	runs := 1
	if flags.Benchmark {
		runs = flags.BenchmarkCount
	}

	start := time.Now()
	for run := 1; run <= runs; run++ {
		now := time.Now()
		results, err := yt.Infer(img)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("%d. processing time: %s\n", run, time.Since(now))

		// In benchmark mode only timing is reported; skip printing and drawing.
		if flags.Benchmark {
			continue
		}

		for i, r := range results {
			fmt.Println("prediction: ", i)
			fmt.Println("class: ", r.Class)
			fmt.Printf("confidence: %.2f\n", r.Probability)
			fmt.Println("bboxes: [", int(r.X1), int(r.Y1), int(r.X2), int(r.Y2), "]")
			fmt.Println("---------------------")
		}

		// Box line width and label size scale with the image width.
		out, err := yolotriton.DrawBoundingBoxes(
			img,
			results,
			int(float64(img.Bounds().Dx())*0.005),
			float64(img.Bounds().Dx())*0.02,
		)
		if err != nil {
			log.Fatal(err)
		}

		if err := yolotriton.SaveImage(out, outputPath(flags.Image, flags.ModelName)); err != nil {
			log.Fatal(err)
		}
	}

	if flags.Benchmark {
		fmt.Println("Avg processing time:", time.Since(start)/time.Duration(runs))
	}
}

// outputPath returns the path used for the annotated copy of image: image with
// its extension (if any) removed, followed by "_<modelName>_out.jpg".
//
// Only a dot inside the final path element counts as an extension, so
// "./images/1.jpg" maps to "./images/1_<modelName>_out.jpg" and
// "data.v2/img.jpg" keeps its directory, instead of everything after the
// first "." in the whole path being discarded.
func outputPath(image, modelName string) string {
	base := image
	if dot := strings.LastIndex(image, "."); dot > strings.LastIndexAny(image, `/\`) {
		base = image[:dot]
	}
	return fmt.Sprintf("%s_%s_out.jpg", base, modelName)
}