Skip to content

Commit e3f7e12

Browse files
Wondertanclaude
andcommitted
feat(rpc): harden RPC server against DOS attacks
- Set WithMaxRequestSize(5 MiB) on go-jsonrpc server (down from 100 MiB default) - Add http.Server timeouts: ReadTimeout, WriteTimeout, IdleTimeout, MaxHeaderBytes - Add per-IP rate limiting middleware (100 req/s sustained, 200 burst) - Add concurrent connection limiting middleware (500 max) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e52e400 commit e3f7e12

File tree

4 files changed

+259
-10
lines changed

4 files changed

+259
-10
lines changed

api/rpc/middleware.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package rpc
2+
3+
import (
4+
"net"
5+
"net/http"
6+
"sync"
7+
"time"
8+
9+
"golang.org/x/time/rate"
10+
)
11+
12+
// connLimit returns middleware that limits the number of concurrent requests.
13+
// When the limit is reached, new requests receive 503 Service Unavailable.
14+
func connLimit(maxConns int, next http.Handler) http.Handler {
15+
sem := make(chan struct{}, maxConns)
16+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
17+
select {
18+
case sem <- struct{}{}:
19+
defer func() { <-sem }()
20+
next.ServeHTTP(w, r)
21+
default:
22+
log.Warnw("connection limit reached, rejecting request", "remote", r.RemoteAddr)
23+
http.Error(w, "server busy, try again later", http.StatusServiceUnavailable)
24+
}
25+
})
26+
}
27+
28+
// rateLimit returns middleware that enforces per-IP rate limiting.
29+
// Requests exceeding the limit receive 429 Too Many Requests.
30+
func rateLimit(rps, burst int, next http.Handler) http.Handler {
31+
var mu sync.Mutex
32+
type entry struct {
33+
limiter *rate.Limiter
34+
lastSeen time.Time
35+
}
36+
limiters := make(map[string]*entry)
37+
rateL := rate.Limit(rps)
38+
39+
// Evict stale entries in the background.
40+
go func() {
41+
ticker := time.NewTicker(5 * time.Minute)
42+
defer ticker.Stop()
43+
for range ticker.C {
44+
mu.Lock()
45+
for ip, e := range limiters {
46+
if time.Since(e.lastSeen) > 10*time.Minute {
47+
delete(limiters, ip)
48+
}
49+
}
50+
mu.Unlock()
51+
}
52+
}()
53+
54+
getLimiter := func(ip string) *rate.Limiter {
55+
mu.Lock()
56+
defer mu.Unlock()
57+
e, ok := limiters[ip]
58+
if !ok {
59+
l := rate.NewLimiter(rateL, burst)
60+
limiters[ip] = &entry{limiter: l, lastSeen: time.Now()}
61+
return l
62+
}
63+
e.lastSeen = time.Now()
64+
return e.limiter
65+
}
66+
67+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
68+
ip := extractIP(r)
69+
if !getLimiter(ip).Allow() {
70+
log.Warnw("rate limit exceeded", "ip", ip)
71+
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
72+
return
73+
}
74+
next.ServeHTTP(w, r)
75+
})
76+
}
77+
78+
// extractIP returns the IP portion of RemoteAddr (strips port).
79+
func extractIP(r *http.Request) string {
80+
host, _, err := net.SplitHostPort(r.RemoteAddr)
81+
if err != nil {
82+
return r.RemoteAddr
83+
}
84+
return host
85+
}

api/rpc/middleware_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package rpc
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"sync"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
const testRemoteAddr = "1.2.3.4:1234"
14+
15+
func TestConnLimit_AllowsWithinLimit(t *testing.T) {
16+
handler := connLimit(2, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
17+
w.WriteHeader(http.StatusOK)
18+
}))
19+
20+
req := httptest.NewRequest("GET", "/", nil)
21+
w := httptest.NewRecorder()
22+
handler.ServeHTTP(w, req)
23+
assert.Equal(t, http.StatusOK, w.Code)
24+
}
25+
26+
func TestConnLimit_RejectsOverLimit(t *testing.T) {
27+
blocked := make(chan struct{})
28+
released := make(chan struct{})
29+
handler := connLimit(1, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30+
if r.Header.Get("X-Block") == "true" {
31+
close(blocked)
32+
<-released
33+
}
34+
w.WriteHeader(http.StatusOK)
35+
}))
36+
37+
var wg sync.WaitGroup
38+
wg.Add(1)
39+
go func() {
40+
defer wg.Done()
41+
req := httptest.NewRequest("GET", "/", nil)
42+
req.Header.Set("X-Block", "true")
43+
w := httptest.NewRecorder()
44+
handler.ServeHTTP(w, req)
45+
}()
46+
47+
<-blocked
48+
49+
req := httptest.NewRequest("GET", "/", nil)
50+
w := httptest.NewRecorder()
51+
handler.ServeHTTP(w, req)
52+
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
53+
54+
close(released)
55+
wg.Wait()
56+
}
57+
58+
func TestConnLimit_ReleasesSlotAfterRequest(t *testing.T) {
59+
handler := connLimit(1, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
60+
w.WriteHeader(http.StatusOK)
61+
}))
62+
63+
req := httptest.NewRequest("GET", "/", nil)
64+
w := httptest.NewRecorder()
65+
handler.ServeHTTP(w, req)
66+
require.Equal(t, http.StatusOK, w.Code)
67+
68+
w = httptest.NewRecorder()
69+
handler.ServeHTTP(w, req)
70+
assert.Equal(t, http.StatusOK, w.Code)
71+
}
72+
73+
func TestRateLimit_AllowsWithinLimit(t *testing.T) {
74+
handler := rateLimit(100, 10, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
75+
w.WriteHeader(http.StatusOK)
76+
}))
77+
78+
req := httptest.NewRequest("GET", "/", nil)
79+
req.RemoteAddr = testRemoteAddr
80+
w := httptest.NewRecorder()
81+
handler.ServeHTTP(w, req)
82+
assert.Equal(t, http.StatusOK, w.Code)
83+
}
84+
85+
func TestRateLimit_RejectsOverBurst(t *testing.T) {
86+
handler := rateLimit(1, 3, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
87+
w.WriteHeader(http.StatusOK)
88+
}))
89+
90+
req := httptest.NewRequest("GET", "/", nil)
91+
req.RemoteAddr = testRemoteAddr
92+
93+
for i := 0; i < 3; i++ {
94+
w := httptest.NewRecorder()
95+
handler.ServeHTTP(w, req)
96+
assert.Equal(t, http.StatusOK, w.Code, "request %d should succeed", i)
97+
}
98+
99+
w := httptest.NewRecorder()
100+
handler.ServeHTTP(w, req)
101+
assert.Equal(t, http.StatusTooManyRequests, w.Code)
102+
}
103+
104+
func TestRateLimit_DifferentIPsIndependent(t *testing.T) {
105+
handler := rateLimit(1, 1, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
106+
w.WriteHeader(http.StatusOK)
107+
}))
108+
109+
req1 := httptest.NewRequest("GET", "/", nil)
110+
req1.RemoteAddr = testRemoteAddr
111+
w := httptest.NewRecorder()
112+
handler.ServeHTTP(w, req1)
113+
assert.Equal(t, http.StatusOK, w.Code)
114+
115+
w = httptest.NewRecorder()
116+
handler.ServeHTTP(w, req1)
117+
assert.Equal(t, http.StatusTooManyRequests, w.Code)
118+
119+
req2 := httptest.NewRequest("GET", "/", nil)
120+
req2.RemoteAddr = "5.6.7.8:5678"
121+
w = httptest.NewRecorder()
122+
handler.ServeHTTP(w, req2)
123+
assert.Equal(t, http.StatusOK, w.Code)
124+
}
125+
126+
func TestExtractIP(t *testing.T) {
127+
tests := []struct {
128+
remoteAddr string
129+
expected string
130+
}{
131+
{testRemoteAddr, "1.2.3.4"},
132+
{"[::1]:1234", "::1"},
133+
{"1.2.3.4", "1.2.3.4"},
134+
}
135+
for _, tt := range tests {
136+
r := &http.Request{RemoteAddr: tt.remoteAddr}
137+
assert.Equal(t, tt.expected, extractIP(r))
138+
}
139+
}

api/rpc/server.go

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ import (
2121

2222
var log = logging.Logger("rpc")
2323

24+
// TODO(@Wondertan): Expose in config if requested
25+
const (
26+
// maxRequestSize is 5 MiB, significantly lower than go-jsonrpc's 100 MiB default.
27+
maxRequestSize = 5 << 20
28+
// maxConcurrentConns caps simultaneous connections to bound goroutine/FD usage.
29+
maxConcurrentConns = 500
30+
// maxRequestsPerSecond is the per-IP sustained rate.
31+
maxRequestsPerSecond = 100
32+
// maxRequestBurst is the per-IP burst allowance.
33+
maxRequestBurst = 200
34+
)
35+
2436
type CORSConfig struct {
2537
Enabled bool
2638
AllowedOrigins []string
@@ -59,7 +71,10 @@ func NewServer(
5971
signer jwt.Signer,
6072
verifier jwt.Verifier,
6173
) *Server {
62-
rpc := jsonrpc.NewServer()
74+
rpc := jsonrpc.NewServer(
75+
jsonrpc.WithMaxRequestSize(maxRequestSize),
76+
)
77+
6378
srv := &Server{
6479
rpc: rpc,
6580
signer: signer,
@@ -76,23 +91,33 @@ func NewServer(
7691
Handler: srv.newHandlerStack(rpc),
7792
// the amount of time allowed to read request headers. set to the default 2 seconds
7893
ReadHeaderTimeout: 2 * time.Second,
94+
ReadTimeout: 30 * time.Second,
95+
WriteTimeout: 60 * time.Second,
96+
IdleTimeout: 120 * time.Second,
97+
MaxHeaderBytes: 1 << 20, // 1 MiB
7998
}
8099

81100
return srv
82101
}
83102

84-
// newHandlerStack returns wrapped rpc related handlers
103+
// newHandlerStack returns wrapped rpc related handlers.
104+
// Middleware order (outermost first): rate-limit → conn-limit → CORS/auth → RPC handler.
85105
func (s *Server) newHandlerStack(core http.Handler) http.Handler {
86-
if s.authDisabled {
106+
var h http.Handler
107+
switch {
108+
case s.authDisabled:
87109
log.Warn("auth disabled, allowing all origins, methods and headers for CORS")
88-
return s.corsAny(core)
89-
}
90-
91-
if s.corsConfig.Enabled {
92-
return s.corsWithConfig(s.authHandler(core))
110+
h = s.corsAny(core)
111+
case s.corsConfig.Enabled:
112+
h = s.corsWithConfig(s.authHandler(core))
113+
default:
114+
h = s.authHandler(core)
93115
}
94116

95-
return s.authHandler(core)
117+
// Apply connection and rate limiting as outermost layers.
118+
h = connLimit(maxConcurrentConns, h)
119+
h = rateLimit(maxRequestsPerSecond, maxRequestBurst, h)
120+
return h
96121
}
97122

98123
// verifyAuth is the RPC server's auth middleware. This middleware is only

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ require (
7777
golang.org/x/crypto v0.49.0
7878
golang.org/x/sync v0.20.0
7979
golang.org/x/text v0.35.0
80+
golang.org/x/time v0.15.0
8081
google.golang.org/grpc v1.79.3
8182
google.golang.org/protobuf v1.36.11
8283
)
@@ -375,7 +376,6 @@ require (
375376
golang.org/x/sys v0.42.0 // indirect
376377
golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c // indirect
377378
golang.org/x/term v0.41.0 // indirect
378-
golang.org/x/time v0.15.0 // indirect
379379
golang.org/x/tools v0.43.0 // indirect
380380
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
381381
gonum.org/v1/gonum v0.17.0 // indirect

0 commit comments

Comments
 (0)