Skip to content

Commit df39a44

Browse files
committed
tests passing
Signed-off-by: Grant Linville <grant@acorn.io>
1 parent d942098 commit df39a44

File tree

9 files changed

+554
-3
lines changed

9 files changed

+554
-3
lines changed

go.mod

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ require (
4040
golang.org/x/sync v0.12.0
4141
google.golang.org/api v0.228.0
4242
gopkg.in/natefinch/lumberjack.v2 v2.2.1
43+
gorm.io/driver/postgres v1.5.7
44+
gorm.io/gorm v1.25.7
4345
k8s.io/apimachinery v0.29.0
4446
)
4547

@@ -62,6 +64,11 @@ require (
6264
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
6365
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
6466
github.com/hashicorp/hcl v1.0.0 // indirect
67+
github.com/jackc/pgpassfile v1.0.0 // indirect
68+
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
69+
github.com/jackc/pgx/v5 v5.4.3 // indirect
70+
github.com/jinzhu/inflection v1.0.0 // indirect
71+
github.com/jinzhu/now v1.1.5 // indirect
6572
github.com/klauspost/compress v1.17.11 // indirect
6673
github.com/kylelemons/godebug v1.1.0 // indirect
6774
github.com/magiconair/properties v1.8.7 // indirect

go.sum

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,16 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
8888
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
8989
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
9090
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
91+
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
92+
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
93+
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
94+
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
95+
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
96+
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
97+
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
98+
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
99+
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
100+
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
91101
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
92102
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
93103
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
@@ -263,5 +273,9 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
263273
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
264274
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
265275
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
276+
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
277+
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
278+
gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
279+
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
266280
k8s.io/apimachinery v0.29.0 h1:+ACVktwyicPz0oc6MTMLwa2Pw3ouLAfAon1wPLtG48o=
267281
k8s.io/apimachinery v0.29.0/go.mod h1:eVBxQ/cwiJxH58eK/jd/vAk4mrxmVlnpBH5J2GbMeis=

pkg/apis/options/options.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ func NewFlagSet() *pflag.FlagSet {
156156
flagSet.Bool("redis-use-cluster", false, "Connect to redis cluster. Must set --redis-cluster-connection-urls to use this feature")
157157
flagSet.StringSlice("redis-cluster-connection-urls", []string{}, "List of Redis cluster connection URLs (eg redis://[USER[:PASSWORD]@]HOST[:PORT]). Used in conjunction with --redis-use-cluster")
158158
flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option")
159+
flagSet.String("postgres-connection-url", "", "URL of postgres server for postgres session storage (eg: postgres://[USER[:PASSWORD]@]HOST[:PORT]/DBNAME)")
160+
flagSet.Int("postgres-max-idle-conns", 10, "Maximum number of idle connections to postgres")
161+
flagSet.Int("postgres-max-open-conns", 100, "Maximum number of open connections to postgres")
162+
flagSet.Int("postgres-conn-max-lifetime", 3600, "Maximum lifetime of a connection to postgres in seconds")
159163
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
160164
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
161165

pkg/apis/options/sessions.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package options
22

33
// SessionOptions contains configuration options for the SessionStore providers.
44
type SessionOptions struct {
5-
Type string `flag:"session-store-type" cfg:"session_store_type"`
6-
Cookie CookieStoreOptions `cfg:",squash"`
7-
Redis RedisStoreOptions `cfg:",squash"`
5+
Type string `flag:"session-store-type" cfg:"session_store_type"`
6+
Cookie CookieStoreOptions `cfg:",squash"`
7+
Redis RedisStoreOptions `cfg:",squash"`
8+
Postgres PostgresStoreOptions `cfg:",squash"`
89
}
910

1011
// CookieSessionStoreType is used to indicate the CookieSessionStore should be
@@ -15,6 +16,10 @@ var CookieSessionStoreType = "cookie"
1516
// used for storing sessions.
1617
var RedisSessionStoreType = "redis"
1718

19+
// PostgresSessionStoreType is used to indicate the PostgresSessionStore should be
20+
// used for storing sessions.
21+
var PostgresSessionStoreType = "postgres"
22+
1823
// CookieStoreOptions contains configuration options for the CookieSessionStore.
1924
type CookieStoreOptions struct {
2025
Minimal bool `flag:"session-cookie-minimal" cfg:"session_cookie_minimal"`
@@ -36,6 +41,14 @@ type RedisStoreOptions struct {
3641
IdleTimeout int `flag:"redis-connection-idle-timeout" cfg:"redis_connection_idle_timeout"`
3742
}
3843

44+
// PostgresStoreOptions contains configuration options for the PostgresSessionStore.
45+
type PostgresStoreOptions struct {
46+
ConnectionURL string `flag:"postgres-connection-url" cfg:"postgres_connection_url"`
47+
MaxIdleConns int `flag:"postgres-max-idle-conns" cfg:"postgres_max_idle_conns"`
48+
MaxOpenConns int `flag:"postgres-max-open-conns" cfg:"postgres_max_open_conns"`
49+
ConnMaxLifetime int `flag:"postgres-conn-max-lifetime" cfg:"postgres_conn_max_lifetime"`
50+
}
51+
3952
func sessionOptionsDefaults() SessionOptions {
4053
return SessionOptions{
4154
Type: CookieSessionStoreType,

pkg/sessions/postgres/lock.go

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
package postgres
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"time"
8+
9+
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
10+
"gorm.io/gorm"
11+
)
12+
13+
// SessionLock represents a lock record in the database
14+
type SessionLock struct {
15+
Key string `gorm:"primaryKey"`
16+
ExpiresAt time.Time
17+
CreatedAt time.Time
18+
UpdatedAt time.Time
19+
}
20+
21+
// Lock implements the sessions.Lock interface using a custom table.
22+
// This implementation uses a dedicated table to track locks with expiration times.
23+
//
24+
// Important notes about this implementation:
25+
// 1. Locks have built-in expiration functionality
26+
// 2. Locks are automatically cleaned up when they expire
27+
// 3. The lock key is stored directly in the table
28+
// 4. If the database connection is lost, the lock will be automatically released
29+
type Lock struct {
30+
db *gorm.DB
31+
key string
32+
}
33+
34+
// NewLock creates a new PostgreSQL lock instance
35+
func NewLock(db *gorm.DB, key string) sessions.Lock {
36+
return &Lock{
37+
db: db,
38+
key: key,
39+
}
40+
}
41+
42+
// Obtain obtains a lock by inserting a record into the session_lock table.
43+
// If a lock already exists and hasn't expired, it will return ErrLockNotObtained.
44+
func (l *Lock) Obtain(ctx context.Context, expiration time.Duration) error {
45+
// Verify database connection is healthy
46+
if err := l.verifyConnection(ctx); err != nil {
47+
return fmt.Errorf("database connection error: %w", err)
48+
}
49+
50+
// Start a transaction to ensure atomicity
51+
tx := l.db.WithContext(ctx).Begin()
52+
if tx.Error != nil {
53+
return fmt.Errorf("error starting transaction: %w", tx.Error)
54+
}
55+
defer func() {
56+
if r := recover(); r != nil {
57+
tx.Rollback()
58+
}
59+
}()
60+
61+
// Clean up expired locks
62+
if err := tx.Where("expires_at < ?", time.Now()).Delete(&SessionLock{}).Error; err != nil {
63+
tx.Rollback()
64+
return fmt.Errorf("error cleaning up expired locks: %w", err)
65+
}
66+
67+
// Check if lock exists and is valid
68+
var existingLock SessionLock
69+
err := tx.Where("key = ?", l.key).First(&existingLock).Error
70+
if err == nil {
71+
// Lock exists and hasn't expired
72+
tx.Rollback()
73+
return sessions.ErrLockNotObtained
74+
}
75+
if !errors.Is(err, gorm.ErrRecordNotFound) {
76+
tx.Rollback()
77+
return fmt.Errorf("error checking existing lock: %w", err)
78+
}
79+
80+
// Create new lock
81+
lock := &SessionLock{
82+
Key: l.key,
83+
ExpiresAt: time.Now().Add(expiration),
84+
}
85+
if err := tx.Create(lock).Error; err != nil {
86+
tx.Rollback()
87+
return fmt.Errorf("error creating lock: %w", err)
88+
}
89+
90+
return tx.Commit().Error
91+
}
92+
93+
// Peek checks if the lock is still held by checking if it exists and hasn't expired
94+
func (l *Lock) Peek(ctx context.Context) (bool, error) {
95+
// Verify database connection is healthy
96+
if err := l.verifyConnection(ctx); err != nil {
97+
return false, fmt.Errorf("database connection error: %w", err)
98+
}
99+
100+
// Clean up expired locks
101+
if err := l.db.WithContext(ctx).Where("expires_at < ?", time.Now()).Delete(&SessionLock{}).Error; err != nil {
102+
return false, fmt.Errorf("error cleaning up expired locks: %w", err)
103+
}
104+
105+
// Check if lock exists and is valid
106+
var lock SessionLock
107+
err := l.db.WithContext(ctx).Where("key = ?", l.key).First(&lock).Error
108+
if err == nil {
109+
return true, nil
110+
}
111+
if err == gorm.ErrRecordNotFound {
112+
return false, nil
113+
}
114+
return false, fmt.Errorf("error checking lock: %w", err)
115+
}
116+
117+
// Refresh refreshes the lock by updating its expiration time
118+
func (l *Lock) Refresh(ctx context.Context, expiration time.Duration) error {
119+
// Verify database connection is healthy
120+
if err := l.verifyConnection(ctx); err != nil {
121+
return fmt.Errorf("database connection error: %w", err)
122+
}
123+
124+
// Start a transaction to ensure atomicity
125+
tx := l.db.WithContext(ctx).Begin()
126+
if tx.Error != nil {
127+
return fmt.Errorf("error starting transaction: %w", tx.Error)
128+
}
129+
defer func() {
130+
if r := recover(); r != nil {
131+
tx.Rollback()
132+
}
133+
}()
134+
135+
// First verify we hold the lock
136+
var existingLock SessionLock
137+
err := tx.Where("key = ? AND expires_at > ?", l.key, time.Now()).First(&existingLock).Error
138+
if err != nil {
139+
tx.Rollback()
140+
if errors.Is(err, gorm.ErrRecordNotFound) {
141+
return sessions.ErrNotLocked
142+
}
143+
return fmt.Errorf("error checking existing lock: %w", err)
144+
}
145+
146+
// Update lock expiration
147+
result := tx.Model(&SessionLock{}).
148+
Where("key = ?", l.key).
149+
Update("expires_at", time.Now().Add(expiration))
150+
if result.Error != nil {
151+
tx.Rollback()
152+
return fmt.Errorf("error refreshing lock: %w", result.Error)
153+
}
154+
if result.RowsAffected == 0 {
155+
tx.Rollback()
156+
return sessions.ErrNotLocked
157+
}
158+
159+
return tx.Commit().Error
160+
}
161+
162+
// Release releases the lock by deleting the record from the session_lock table
163+
func (l *Lock) Release(ctx context.Context) error {
164+
// Verify database connection is healthy
165+
if err := l.verifyConnection(ctx); err != nil {
166+
return fmt.Errorf("database connection error: %w", err)
167+
}
168+
169+
// Start a transaction to ensure atomicity
170+
tx := l.db.WithContext(ctx).Begin()
171+
if tx.Error != nil {
172+
return fmt.Errorf("error starting transaction: %w", tx.Error)
173+
}
174+
defer func() {
175+
if r := recover(); r != nil {
176+
tx.Rollback()
177+
}
178+
}()
179+
180+
// Clean up expired locks
181+
if err := tx.Where("expires_at < ?", time.Now()).Delete(&SessionLock{}).Error; err != nil {
182+
tx.Rollback()
183+
return fmt.Errorf("error cleaning up expired locks: %w", err)
184+
}
185+
186+
// Delete the lock
187+
result := tx.Where("key = ?", l.key).Delete(&SessionLock{})
188+
if result.Error != nil {
189+
tx.Rollback()
190+
return fmt.Errorf("error releasing lock: %w", result.Error)
191+
}
192+
if result.RowsAffected == 0 {
193+
tx.Rollback()
194+
return sessions.ErrNotLocked
195+
}
196+
197+
return tx.Commit().Error
198+
}
199+
200+
// verifyConnection checks if the database connection is healthy
201+
func (l *Lock) verifyConnection(ctx context.Context) error {
202+
sqlDB, err := l.db.DB()
203+
if err != nil {
204+
return fmt.Errorf("error getting database instance: %w", err)
205+
}
206+
if err := sqlDB.PingContext(ctx); err != nil {
207+
return fmt.Errorf("error pinging database: %w", err)
208+
}
209+
return nil
210+
}

0 commit comments

Comments
 (0)