-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimple_generation.cpp
More file actions
77 lines (65 loc) · 2.13 KB
/
simple_generation.cpp
File metadata and controls
77 lines (65 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
* @file simple_generation.cpp
* @brief Simple example of text generation with MLXR
*
* Usage:
* ./simple_generation <model_dir> <tokenizer_path> <prompt>
*
* Example:
* ./simple_generation ./models/TinyLlama-1.1B ./models/tokenizer.model "Once
* upon a time"
*/
#include <exception>
#include <iostream>
#include <string>

#include "runtime/engine.h"
namespace {

/// Joins argv[first, argc) with single spaces.
///
/// Lets an unquoted multi-word prompt (e.g. `... tokenizer.model Once upon a
/// time`) reach the engine intact instead of being truncated to its first
/// word. A single (quoted) argument is returned unchanged.
std::string join_args(int argc, char* argv[], int first) {
  std::string joined;
  for (int i = first; i < argc; ++i) {
    if (i > first) {
      joined += ' ';
    }
    joined += argv[i];
  }
  return joined;
}

}  // namespace

/// Entry point: loads a model + tokenizer, generates a continuation of the
/// prompt, and prints it.
///
/// @param argc Argument count; at least 4 is required.
/// @param argv argv[1] = model directory, argv[2] = tokenizer path,
///             argv[3..] = prompt (extra arguments are joined with spaces).
/// @return 0 on success; 1 on bad usage, load failure, or generation failure.
int main(int argc, char* argv[]) {
  // Parse command line arguments
  if (argc < 4) {
    std::cerr << "Usage: " << argv[0]
              << " <model_dir> <tokenizer_path> <prompt>" << std::endl;
    std::cerr << "\nExample:" << std::endl;
    std::cerr << "  " << argv[0]
              << " ./models/TinyLlama-1.1B ./models/tokenizer.model \"Once "
                 "upon a time\""
              << std::endl;
    return 1;
  }

  const std::string model_dir = argv[1];
  const std::string tokenizer_path = argv[2];
  const std::string prompt = join_args(argc, argv, 3);

  std::cout << "=== MLXR Simple Generation Example ===" << std::endl;
  std::cout << "Model directory: " << model_dir << std::endl;
  std::cout << "Tokenizer: " << tokenizer_path << std::endl;
  std::cout << "Prompt: \"" << prompt << "\"" << std::endl;
  std::cout << std::endl;

  // Configure generation
  mlxr::runtime::GenerationConfig config;
  config.max_new_tokens = 50;
  config.sampler_config.temperature = 0.7f;
  config.sampler_config.top_p = 0.9f;
  config.echo_prompt = true;
  config.verbose = true;

  // The outer try guards model loading. If load_engine throws instead of
  // returning a null engine, we report it rather than letting the exception
  // escape main (which would call std::terminate with no diagnostic).
  try {
    // Load engine
    std::cout << "Loading model..." << std::endl;
    auto engine = mlxr::runtime::load_engine(model_dir, tokenizer_path, config);
    if (!engine) {
      std::cerr << "Failed to load engine" << std::endl;
      return 1;
    }
    std::cout << "Model loaded successfully!" << std::endl;
    std::cout << std::endl;

    // Generate text. Its own try keeps the "Generation failed" message
    // distinct from load failures.
    std::cout << "Generating..." << std::endl;
    std::cout << "---" << std::endl;
    try {
      const std::string generated = engine->generate(prompt);
      std::cout << std::endl;
      std::cout << "---" << std::endl;
      std::cout << "\nGenerated text:" << std::endl;
      std::cout << generated << std::endl;
    } catch (const std::exception& e) {
      std::cerr << "Generation failed: " << e.what() << std::endl;
      return 1;
    }
  } catch (const std::exception& e) {
    std::cerr << "Failed to load engine: " << e.what() << std::endl;
    return 1;
  }

  return 0;
}