Skip to content

Commit 6202639

Browse files
authored
feat: override response header option for HTTP Loader (#417)
* override header if specified * http-loader-override-response-headers * test cases * cleanup
1 parent 3a2809d commit 6202639

File tree

8 files changed

+95
-2
lines changed

8 files changed

+95
-2
lines changed

blob.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ type Blob struct {
5252
contentType string
5353
memory *memory
5454

55-
Stat *Stat
55+
Header http.Header
56+
Stat *Stat
5657
}
5758

5859
// Stat Blob stat attributes

config/config_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ func TestBasic(t *testing.T) {
6060
"-imagor-cache-header-ttl", "169h",
6161
"-imagor-cache-header-swr", "167h",
6262
"-http-loader-insecure-skip-verify-transport",
63+
"-http-loader-override-response-headers", "cache-control,content-type",
6364
"-http-loader-base-url", "https://www.example.com/foo.org",
6465
})
6566
app := srv.App.(*imagor.Imagor)
@@ -85,6 +86,7 @@ func TestBasic(t *testing.T) {
8586
httpLoader := app.Loaders[0].(*httploader.HTTPLoader)
8687
assert.True(t, httpLoader.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify)
8788
assert.Equal(t, "https://www.example.com/foo.org", httpLoader.BaseURL.String())
89+
assert.Equal(t, []string{"cache-control", "content-type"}, httpLoader.OverrideResponseHeaders)
8890
}
8991

9092
func TestVersion(t *testing.T) {

config/httpconfig.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ func withHTTPLoader(fs *flag.FlagSet, cb func() (*zap.Logger, bool)) imagor.Opti
1414
var (
1515
httpLoaderForwardHeaders = fs.String("http-loader-forward-headers", "",
1616
"Forward request header to HTTP Loader request by csv e.g. User-Agent,Accept")
17+
httpLoaderOverrideResponseHeaders = fs.String("http-loader-override-response-headers", "",
18+
"Override HTTP Loader response header to image response by csv e.g. Cache-Control,Expires")
1719
httpLoaderForwardClientHeaders = fs.Bool("http-loader-forward-client-headers", false,
1820
"Forward browser client request headers to HTTP Loader request")
1921
httpLoaderForwardAllHeaders = fs.Bool("http-loader-forward-all-headers", false,
@@ -58,6 +60,7 @@ func withHTTPLoader(fs *flag.FlagSet, cb func() (*zap.Logger, bool)) imagor.Opti
5860
*httpLoaderForwardClientHeaders || *httpLoaderForwardAllHeaders),
5961
httploader.WithAccept(*httpLoaderAccept),
6062
httploader.WithForwardHeaders(*httpLoaderForwardHeaders),
63+
httploader.WithOverrideResponseHeaders(*httpLoaderOverrideResponseHeaders),
6164
httploader.WithAllowedSources(*httpLoaderAllowedSources),
6265
httploader.WithAllowedSourceRegexps(*httpLoaderAllowedSourceRegexp),
6366
httploader.WithMaxAllowedSize(*httpLoaderMaxAllowedSize),

imagor.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,11 @@ func (app *Imagor) ServeHTTP(w http.ResponseWriter, r *http.Request) {
203203
if r.Header.Get("Imagor-Raw") != "" {
204204
w.Header().Set("Content-Security-Policy", "script-src 'none'")
205205
}
206+
if h := blob.Header; h != nil {
207+
for key := range h {
208+
w.Header().Set(key, h.Get(key))
209+
}
210+
}
206211
if checkStatNotModified(w, r, blob.Stat) {
207212
w.WriteHeader(http.StatusNotModified)
208213
return

imagor_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,31 @@ func TestWithRaw(t *testing.T) {
331331
assert.Equal(t, "bar", w.Header().Get("Content-Type"))
332332
}
333333

334+
func TestWithOverrideHeader(t *testing.T) {
335+
app := New(
336+
WithDebug(true),
337+
WithUnsafe(true),
338+
WithLogger(zap.NewExample()),
339+
WithLoaders(loaderFunc(func(r *http.Request, image string) (*Blob, error) {
340+
blob := NewBlobFromBytes([]byte("foo"))
341+
blob.SetContentType("bar")
342+
blob.Header = make(http.Header)
343+
blob.Header.Set("Content-Type", "tada")
344+
blob.Header.Set("Foo", "bar")
345+
blob.Header.Set("asdf", "fghj")
346+
return blob, nil
347+
})),
348+
)
349+
w := httptest.NewRecorder()
350+
app.ServeHTTP(w, httptest.NewRequest(
351+
http.MethodGet, "https://example.com/unsafe/filters:fill(red):raw()/gopher.png", nil))
352+
assert.Equal(t, 200, w.Code)
353+
assert.Equal(t, "foo", w.Body.String())
354+
assert.Equal(t, "script-src 'none'", w.Header().Get("Content-Security-Policy"))
355+
assert.Equal(t, "tada", w.Header().Get("Content-Type"))
356+
assert.Equal(t, "fghj", w.Header().Get("ASDF"))
357+
}
358+
334359
func TestNewBlobFromPathNotFound(t *testing.T) {
335360
loader := loaderFunc(func(r *http.Request, image string) (*Blob, error) {
336361
return NewBlobFromFile("./non-exists-path"), nil

loader/httploader/httploader.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ type HTTPLoader struct {
6363
// OverrideHeaders override image request headers
6464
OverrideHeaders map[string]string
6565

66+
// OverrideResponseHeaders override image response header from HTTP Loader response
67+
OverrideResponseHeaders []string
68+
6669
// AllowedSources list of sources allowed to load from
6770
AllowedSources []AllowedSource
6871

@@ -204,6 +207,14 @@ func (h *HTTPLoader) Get(r *http.Request, image string) (*imagor.Blob, error) {
204207
}
205208
once.Do(func() {
206209
blob.SetContentType(resp.Header.Get("Content-Type"))
210+
if len(h.OverrideResponseHeaders) > 0 {
211+
blob.Header = make(http.Header)
212+
for _, key := range h.OverrideResponseHeaders {
213+
if val := resp.Header.Get(key); val != "" {
214+
blob.Header.Set(key, val)
215+
}
216+
}
217+
}
207218
})
208219
body := resp.Body
209220
size, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)

loader/httploader/httploader_test.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func (t testTransport) RoundTrip(r *http.Request) (w *http.Response, err error)
2323
w = &http.Response{
2424
StatusCode: http.StatusOK,
2525
Body: io.NopCloser(strings.NewReader(res)),
26-
Header: map[string][]string{},
26+
Header: make(http.Header),
2727
}
2828
w.Header.Set("Content-Type", "image/jpeg")
2929
return
@@ -48,6 +48,7 @@ type test struct {
4848
name string
4949
target string
5050
result string
51+
header map[string]string
5152
err string
5253
}
5354

@@ -80,6 +81,11 @@ func doTests(t *testing.T, loader imagor.Loader, tests []test) {
8081
}
8182
assert.Equal(t, tt.err, msg)
8283
}
84+
if tt.header != nil {
85+
for key, val := range tt.header {
86+
assert.Equal(t, val, b.Header.Get(key))
87+
}
88+
}
8389
})
8490
}
8591
}
@@ -492,6 +498,31 @@ func TestWithForwardHeadersOverrideUserAgent(t *testing.T) {
492498
})
493499
}
494500

501+
func TestWithOverrideResponseHeader(t *testing.T) {
502+
doTests(t, New(
503+
WithTransport(roundTripFunc(func(r *http.Request) (w *http.Response, err error) {
504+
res := &http.Response{
505+
StatusCode: http.StatusOK,
506+
Header: map[string][]string{},
507+
Body: io.NopCloser(strings.NewReader("ok")),
508+
}
509+
res.Header.Set("Content-Type", "image/jpeg")
510+
res.Header.Set("Foo", "Bar")
511+
return res, nil
512+
})),
513+
WithOverrideResponseHeaders("foo"),
514+
), []test{
515+
{
516+
name: "user agent",
517+
target: "https://foo.bar/baz",
518+
result: "ok",
519+
header: map[string]string{
520+
"Foo": "Bar",
521+
},
522+
},
523+
})
524+
}
525+
495526
func TestWithForwardClientHeaders(t *testing.T) {
496527
doTests(t, New(
497528
WithTransport(roundTripFunc(func(r *http.Request) (w *http.Response, err error) {

loader/httploader/option.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,21 @@ func WithForwardHeaders(headers ...string) Option {
5959
}
6060
}
6161

62+
// WithOverrideResponseHeaders with override selected response headers option
63+
func WithOverrideResponseHeaders(headers ...string) Option {
64+
return func(h *HTTPLoader) {
65+
for _, raw := range headers {
66+
splits := strings.Split(raw, ",")
67+
for _, header := range splits {
68+
header = strings.TrimSpace(header)
69+
if len(header) > 0 {
70+
h.OverrideResponseHeaders = append(h.OverrideResponseHeaders, header)
71+
}
72+
}
73+
}
74+
}
75+
}
76+
6277
// WithForwardClientHeaders with forward browser request headers option
6378
func WithForwardClientHeaders(enabled bool) Option {
6479
return func(h *HTTPLoader) {

0 commit comments

Comments
 (0)