-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrest_server.h
More file actions
391 lines (326 loc) · 10.7 KB
/
rest_server.h
File metadata and controls
391 lines (326 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
// Copyright © 2025 MLXR Development
// REST server with OpenAI-compatible API endpoints
#pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <vector>
namespace mlxr {
// Forward declarations: in this header these types are only referenced
// through std::shared_ptr members and parameters, so their full definitions
// are not pulled in here.
class LlamaModel;

namespace runtime {
class Engine;
class Tokenizer;
}  // namespace runtime

namespace scheduler {
class Scheduler;
}  // namespace scheduler

namespace registry {
class ModelRegistry;
}  // namespace registry

// Bring the runtime types into this namespace so they can be named without
// the runtime:: qualifier.
using runtime::Engine;
using runtime::Tokenizer;
namespace server {
// Forward declarations of server-side collaborators. RestServer only holds
// these through smart pointers (std::shared_ptr / std::unique_ptr).
class SchedulerWorker;
class OllamaAPIHandler;
// ==============================================================================
// Request/Response Data Structures
// ==============================================================================
/// One message of a chat conversation.
struct ChatMessage {
  std::string role{};     ///< "system", "user", "assistant", "function"
  std::string content{};  ///< Message text.
  std::optional<std::string> name{};           ///< Optional `name` field.
  std::optional<std::string> function_call{};  ///< Optional function-call payload.
};
/// Describes a callable function offered to the model (function calling).
struct FunctionDefinition {
  std::string name{};             ///< Function identifier.
  std::string description{};      ///< Human-readable description.
  std::string parameters_json{};  ///< Parameter JSON schema, kept as raw text.
};
/// A tool the model may invoke; only the function form is modelled here.
struct ToolDefinition {
  std::string type{};             ///< Tool kind, e.g. "function" (no default set).
  FunctionDefinition function{};  ///< Name, description and schema of the function.
};
/**
 * @brief Chat completion request (OpenAI-compatible).
 *
 * `model` and `messages` are plain members. Every other parameter is a
 * std::optional, so "not specified" is distinguishable from an explicit value.
 */
struct ChatCompletionRequest {
  std::string model{};
  std::vector<ChatMessage> messages{};

  // ---- optional sampling / behaviour parameters ----
  std::optional<float> temperature{};
  std::optional<float> top_p{};
  std::optional<int> top_k{};
  std::optional<float> repetition_penalty{};
  std::optional<int> max_tokens{};
  std::optional<bool> stream{};
  std::optional<std::vector<std::string>> stop{};
  std::optional<float> presence_penalty{};
  std::optional<float> frequency_penalty{};
  std::optional<int> n{};
  std::optional<std::string> user{};
  std::optional<std::vector<ToolDefinition>> tools{};
  std::optional<std::string> tool_choice{};
  std::optional<int> seed{};
};
/// Token accounting attached to a response. Each counter is a plain field;
/// total_tokens is not derived automatically from the other two.
struct UsageInfo {
  int prompt_tokens{0};
  int completion_tokens{0};
  int total_tokens{0};
};
/// One generated alternative in a non-streaming chat completion.
struct ChatCompletionChoice {
  int index{0};           ///< Position of this choice in the response.
  ChatMessage message{};  ///< The generated message.
  /// Why generation ended: "stop", "length", "function_call",
  /// "content_filter".
  std::string finish_reason{};
};
/// Complete (non-streaming) chat completion response.
struct ChatCompletionResponse {
  std::string id{};
  std::string object{"chat.completion"};  ///< Object type tag (default value).
  int64_t created{0};  ///< Creation time (presumably a Unix timestamp).
  std::string model{};
  std::vector<ChatCompletionChoice> choices{};
  UsageInfo usage{};
};
/// Incremental update carried by one streaming chunk; each part is present
/// only when that chunk changes it.
struct ChatCompletionDelta {
  std::optional<std::string> role{};
  std::optional<std::string> content{};
  std::optional<std::string> function_call{};
};
/// One choice inside a streaming chat completion chunk.
struct ChatCompletionStreamChoice {
  int index{0};                 ///< Position of this choice.
  ChatCompletionDelta delta{};  ///< What changed in this chunk.
  /// Why generation ended. NOTE(review): a std::string cannot express a
  /// JSON null, so presumably an empty value means "not finished yet" —
  /// confirm against the chunk serializer.
  std::string finish_reason{};
};
/// One chunk of a streaming chat completion.
struct ChatCompletionChunk {
  std::string id{};
  std::string object{"chat.completion.chunk"};  ///< Object type tag (default value).
  int64_t created{0};
  std::string model{};
  std::vector<ChatCompletionStreamChoice> choices{};
};
/**
 * @brief Plain (non-chat) text completion request.
 *
 * `model` and `prompt` are plain members; all other parameters are
 * std::optional so an omitted value can be told apart from an explicit one.
 */
struct CompletionRequest {
  std::string model{};
  std::string prompt{};

  // ---- optional sampling / behaviour parameters ----
  std::optional<float> temperature{};
  std::optional<float> top_p{};
  std::optional<int> top_k{};
  std::optional<float> repetition_penalty{};
  std::optional<int> max_tokens{};
  std::optional<bool> stream{};
  std::optional<std::vector<std::string>> stop{};
  std::optional<float> presence_penalty{};
  std::optional<float> frequency_penalty{};
  std::optional<int> n{};
  std::optional<std::string> suffix{};
  std::optional<int> seed{};
};
/// One generated alternative of a text completion.
struct CompletionChoice {
  int index{0};                 ///< Position of this choice.
  std::string text{};           ///< Generated text.
  std::string finish_reason{};  ///< Why generation ended.
};
/// Response to a plain text completion request.
struct CompletionResponse {
  std::string id{};
  std::string object{"text_completion"};  ///< Object type tag (default value).
  int64_t created{0};
  std::string model{};
  std::vector<CompletionChoice> choices{};
  UsageInfo usage{};
};
/**
 * @brief Embedding request.
 *
 * `input` holds exactly one string. The OpenAI-style API also allows an
 * array of strings, which this field cannot represent.
 * TODO(review): confirm whether the request parser rejects array input or
 * reduces it to a single string.
 */
struct EmbeddingRequest {
  std::string model;
  std::string input;                           // Single input text only (see above).
  std::optional<std::string> encoding_format;  // "float" or "base64"
  std::optional<std::string> user;
};
/// One embedding vector within an embedding response.
struct EmbeddingObject {
  int index{0};                    ///< Position of the corresponding input.
  std::vector<float> embedding{};  ///< The embedding values.
  std::string object{"embedding"}; ///< Object type tag (default value).
};
/// Response to an embedding request: a list of embedding objects.
struct EmbeddingResponse {
  std::string object{"list"};  ///< Object type tag (default value).
  std::vector<EmbeddingObject> data{};
  std::string model{};
  UsageInfo usage{};
};
/// Metadata describing one available model.
struct ModelInfo {
  std::string id{};
  std::string object{"model"};   ///< Object type tag (default value).
  int64_t created{0};
  std::string owned_by{"mlxr"};  ///< Owner label (defaults to "mlxr").
};
/// Response listing the available models.
struct ModelListResponse {
  std::string object{"list"};  ///< Object type tag (default value).
  std::vector<ModelInfo> data{};
};
/// Error payload: a single `error` member with the details below.
struct ErrorResponse {
  /// Details of one error.
  struct ErrorDetail {
    std::string message{};              ///< Human-readable description.
    std::string type{};                 ///< Error category string.
    std::optional<std::string> code{};  ///< Optional error code.
  };
  ErrorDetail error{};
};
// ==============================================================================
// HTTP Request/Response Structures
// ==============================================================================
/// Inbound HTTP request handed to the route handlers.
/// NOTE(review): std::map compares keys case-sensitively, whereas HTTP
/// header names are case-insensitive; lookups in `headers` only work if the
/// producer normalizes names — confirm in the transport layer.
struct HttpRequest {
  std::string method{};                              ///< GET, POST, etc.
  std::string path{};                                ///< Request path.
  std::map<std::string, std::string> headers{};      ///< Header name -> value.
  std::string body{};                                ///< Raw request body.
  std::map<std::string, std::string> query_params{}; ///< Query key -> value.
};
/// Outbound HTTP response produced by the route handlers.
struct HttpResponse {
  int status_code{200};                          ///< Defaults to 200.
  std::map<std::string, std::string> headers{};  ///< Header name -> value.
  std::string body{};                            ///< Response body.
};
// ==============================================================================
// REST Server Configuration
// ==============================================================================
/**
 * @brief Configuration for RestServer.
 *
 * Every field has a default value, so a default-constructed ServerConfig
 * can be used directly.
 */
struct ServerConfig {
  // ---- TCP listener ----
  std::string bind_address{"127.0.0.1"};
  int port{8080};

  // ---- Unix domain socket ----
  bool enable_unix_socket{true};
  // NOTE(review): the leading "~" is not expanded by the OS; presumably the
  // server expands it before binding — confirm in the implementation.
  std::string unix_socket_path{
      "~/Library/Application Support/MLXRunner/run/mlxrunner.sock"};

  // ---- General behaviour ----
  bool enable_cors{true};
  int max_connections{100};
  int thread_pool_size{4};
  std::string api_key;  ///< Optional API key for authentication (empty by default).
  bool enable_metrics{true};
  std::string log_level{"info"};

  // ---- Connection limits and timeouts ----
  int read_timeout_sec{30};        ///< Read timeout, seconds.
  int write_timeout_sec{30};       ///< Write timeout, seconds.
  int keep_alive_max_count{100};   ///< Max requests per connection.
  int keep_alive_timeout_sec{5};   ///< Keep-alive timeout, seconds.
  size_t payload_max_length{100 * 1024 * 1024};  ///< Max payload size, bytes (100 MiB).
};
// ==============================================================================
// REST Server Class
// ==============================================================================
/// Sink for a streaming response: receives one chunk per call and reports
/// delivery — true if the chunk was sent, false if the connection closed.
using StreamCallback = std::function<bool(const std::string&)>;
/**
 * @brief HTTP server exposing OpenAI-compatible REST endpoints.
 *
 * Non-copyable. The endpoint handlers are declared virtual so a subclass can
 * customize individual routes. For that reason the destructor is virtual as
 * well, so destroying a derived server through a RestServer pointer (e.g. a
 * std::unique_ptr<RestServer>) is well-defined. Platform-specific details
 * live behind the pimpl member impl_.
 */
class RestServer {
 public:
  /// @brief Construct a server for the given configuration.
  explicit RestServer(const ServerConfig& config);

  // Virtual because the handle_* methods below are meant to be overridden:
  // deleting a derived object through a base pointer whose destructor is
  // non-virtual is undefined behavior.
  virtual ~RestServer();

  // Non-copyable.
  RestServer(const RestServer&) = delete;
  RestServer& operator=(const RestServer&) = delete;

  /// @brief Initialize the server.
  /// @return Success flag (NOTE(review): presumably true on success).
  bool initialize();

  /// @brief Start the server.
  /// @return Success flag (NOTE(review): presumably true on success).
  bool start();

  /// @brief Stop the server.
  void stop();

  /// @brief Whether the server is running (returns running_).
  bool is_running() const { return running_; }

  /// @brief The configuration held by this server.
  const ServerConfig& config() const { return config_; }

  // Set model and inference engine (legacy - use load_model instead)
  void set_model(std::shared_ptr<LlamaModel> model);
  void set_tokenizer(std::shared_ptr<Tokenizer> tokenizer);
  void set_engine(std::shared_ptr<Engine> engine);
  void set_scheduler(std::shared_ptr<scheduler::Scheduler> scheduler);
  void set_registry(std::shared_ptr<registry::ModelRegistry> registry);
  void set_worker(std::shared_ptr<SchedulerWorker> worker);

  // Model loading and management

  /**
   * @brief Load a model by name
   * @param model_name Model name or ID from registry
   * @return true if model loaded successfully
   */
  bool load_model(const std::string& model_name);

  /**
   * @brief Unload a model
   * @param model_name Model name to unload
   * @return true if model unloaded successfully
   */
  bool unload_model(const std::string& model_name);

  /**
   * @brief Get currently loaded model name
   */
  std::string current_model() const;

  // Endpoint handlers (can be overridden for custom behavior). Each maps an
  // HTTP request to a complete HTTP response.
  virtual HttpResponse handle_chat_completion(const HttpRequest& request);
  virtual HttpResponse handle_completion(const HttpRequest& request);
  virtual HttpResponse handle_embedding(const HttpRequest& request);
  virtual HttpResponse handle_models(const HttpRequest& request);
  virtual HttpResponse handle_model_info(const HttpRequest& request);

 private:
  // Configuration
  ServerConfig config_;

  // Lifecycle flags. The default member initializers guarantee these are
  // never read indeterminate (e.g. by is_running()) even if a constructor
  // omits them; a constructor's member-initializer list still takes
  // precedence.
  // NOTE(review): if start()/stop() and is_running() run on different
  // threads, these should become std::atomic<bool>.
  bool running_ = false;
  bool initialized_ = false;

  // Model and inference components (legacy)
  std::shared_ptr<LlamaModel> model_;
  std::shared_ptr<Tokenizer> tokenizer_;
  std::shared_ptr<Engine> engine_;
  std::shared_ptr<scheduler::Scheduler> scheduler_;
  std::shared_ptr<registry::ModelRegistry> registry_;
  std::shared_ptr<SchedulerWorker> worker_;

  // Model loading and management
  std::string current_model_name_;
  mutable std::mutex model_mutex_;  // Protect model loading/unloading (mutable for const methods)

  // API handlers
  std::unique_ptr<OllamaAPIHandler> ollama_handler_;

  // Request routing
  HttpResponse route_request(const HttpRequest& request);

  // Request parsing (NOTE(review): std::nullopt presumably signals a body
  // that could not be parsed — confirm in the implementation).
  std::optional<ChatCompletionRequest> parse_chat_completion_request(
      const std::string& json);
  std::optional<CompletionRequest> parse_completion_request(
      const std::string& json);
  std::optional<EmbeddingRequest> parse_embedding_request(
      const std::string& json);

  // Response serialization
  std::string serialize_chat_completion_response(
      const ChatCompletionResponse& response);
  std::string serialize_completion_response(const CompletionResponse& response);
  std::string serialize_embedding_response(const EmbeddingResponse& response);
  std::string serialize_model_list_response(const ModelListResponse& response);
  std::string serialize_error_response(const ErrorResponse& response);
  std::string serialize_chat_completion_chunk(const ChatCompletionChunk& chunk);

  // Streaming support: the callback receives each chunk and returns false
  // once the connection has been closed.
  void stream_chat_completion(const ChatCompletionRequest& request,
                              StreamCallback callback);
  void stream_completion(const CompletionRequest& request,
                         StreamCallback callback);

  // Utility methods
  std::string generate_request_id();
  int64_t current_timestamp();
  HttpResponse create_error_response(int status_code,
                                     const std::string& message);
  bool validate_api_key(const HttpRequest& request);

  // Server implementation details (platform-specific)
  struct Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace server
} // namespace mlxr