Commit

Add '--rpc' command line option
rgerganov committed Apr 23, 2024
1 parent 89f4d6e commit 2ef868d
Showing 7 changed files with 45 additions and 10 deletions.
9 changes: 9 additions & 0 deletions common/common.cpp
@@ -999,6 +999,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
#endif // GGML_USE_CUDA_SYCL_VULKAN
return true;
}
if (arg == "--rpc") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.rpc_servers = argv[i];
return true;
}
if (arg == "--no-mmap") {
params.use_mmap = false;
return true;
@@ -1507,6 +1515,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
}
printf(" --rpc SERVERS comma separated list of RPC servers\n");
printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
printf(" -gan N, --grp-attn-n N\n");
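Illustration (not part of this commit): the sketch below shows what the new option leaves in gpt_params after parsing. The gpt_params_parse() entry point and the sample invocation are assumptions not shown in this diff; the endpoints simply reuse the defaults that were previously hard-coded in ggml-rpc.cpp (see the hunk further down).

// Sketch: with `./main --rpc localhost:50051,localhost:50052 ...` the parser above
// stores the comma-separated list verbatim in params.rpc_servers; splitting into
// individual endpoints happens later, in ggml_rpc_init().
#include "common.h"
#include <cstdio>

int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) { // assumed entry point (not shown above)
        return 1;
    }
    printf("rpc_servers = '%s'\n", params.rpc_servers.c_str());
    return 0;
}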
1 change: 1 addition & 0 deletions common/common.h
@@ -80,6 +80,7 @@ struct gpt_params {
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold
std::string rpc_servers = ""; // comma separated list of RPC servers

ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
1 change: 1 addition & 0 deletions examples/main/main.cpp
@@ -187,6 +187,7 @@ int main(int argc, char ** argv) {
LOG("%s: llama backend init\n", __func__);
llama_backend_init();
llama_numa_init(params.numa);
llama_rpc_init(params.rpc_servers.empty() ? nullptr : params.rpc_servers.c_str());

llama_model * model;
llama_context * ctx;
31 changes: 21 additions & 10 deletions ggml-rpc.cpp
@@ -3,7 +3,7 @@
#include "ggml-backend-impl.h"
#include "ggml-rpc.grpc.pb.h"

-#ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

@@ -129,7 +129,9 @@ static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const ggml::T

GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
UNUSED(buffer);
-GGML_ASSERT(!ggml_is_quantized(tensor->type) && "quantized tensors not supported");
+if (ggml_is_quantized(tensor->type)) {
+GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+}
}

GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@@ -344,11 +346,20 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .event_synchronize = */ NULL,
};

-// TODO: this should be read from the command line or some configuration file
-static std::vector<std::string> SERVER_ENDPOINTS = {
-"localhost:50051",
-"localhost:50052",
-};
+static std::vector<std::string> endpoints;
+
+GGML_API GGML_CALL void ggml_rpc_init(const char * rpc_servers) {
+endpoints.clear();
+GGML_ASSERT(rpc_servers != NULL);
+std::string servers(rpc_servers);
+size_t pos = 0;
+while ((pos = servers.find(",")) != std::string::npos) {
+std::string server = servers.substr(0, pos);
+endpoints.push_back(server);
+servers.erase(0, pos + 1);
+}
+endpoints.push_back(servers);
+}

static ggml_backend_t instances[GGML_RPC_MAX_SERVERS] = {0};

@@ -364,7 +375,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(int server_id) {
if (instances[server_id]) {
return instances[server_id];
}
-std::string endpoint = SERVER_ENDPOINTS[server_id];
+std::string endpoint = endpoints[server_id];
GGML_PRINT_DEBUG("Connecting to %s\n", endpoint.c_str());
auto channel = grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials());
std::shared_ptr<ggml::Backend::Stub> stub = ggml::Backend::NewStub(channel);
@@ -400,14 +411,14 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
}

GGML_API GGML_CALL int ggml_backend_rpc_get_server_count(void) {
-return SERVER_ENDPOINTS.size();
+return endpoints.size();
}

// Server-side implementation of the RPC backend

BackendImpl::BackendImpl() {
// the RPC backend simply delegates to one of the existing backends
-#ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUDA
fprintf(stderr, "%s: using CUDA backend\n", __func__);
backend = ggml_backend_cuda_init(0); // init device 0
if (!backend) {
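Illustration (not part of this commit): the hard-coded SERVER_ENDPOINTS list is gone; ggml_rpc_init() now splits whatever string the caller passes. The standalone sketch below mirrors that loop on a sample string, reusing the old default endpoints as input.

// Sketch: same splitting logic as the ggml_rpc_init() added above.
#include <cstdio>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> endpoints;
    std::string servers = "localhost:50051,localhost:50052"; // sample input
    size_t pos = 0;
    while ((pos = servers.find(",")) != std::string::npos) {
        endpoints.push_back(servers.substr(0, pos)); // everything before the comma
        servers.erase(0, pos + 1);                   // drop it plus the comma
    }
    endpoints.push_back(servers); // the final (or only) entry
    for (const auto & e : endpoints) {
        printf("endpoint: %s\n", e.c_str()); // localhost:50051, then localhost:50052
    }
    return 0;
}

Note that the final push_back() runs unconditionally, so an empty string or a trailing comma yields an empty endpoint entry; ggml_rpc_init() only asserts that the pointer itself is non-null.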
2 changes: 2 additions & 0 deletions ggml-rpc.h
@@ -10,6 +10,8 @@ extern "C" {

#define GGML_RPC_MAX_SERVERS 16

GGML_API GGML_CALL void ggml_rpc_init(const char * rpc_servers);

// backend API
GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(int server_id);
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
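Illustration (not part of this commit): with ggml_rpc_init() declared here, client code can initialize the endpoint list once and then create one backend per server by index, together with ggml_backend_rpc_get_server_count() from ggml-rpc.cpp. The endpoints and the NULL-on-failure check below are assumptions.

// Sketch: enumerate and connect the configured RPC servers.
#include "ggml-rpc.h"
#include <cstdio>

int main() {
    // The backend keeps a fixed-size instance cache (GGML_RPC_MAX_SERVERS = 16),
    // so the list is expected to stay within that many entries.
    ggml_rpc_init("localhost:50051,localhost:50052");

    int n = ggml_backend_rpc_get_server_count();
    for (int i = 0; i < n; i++) {
        ggml_backend_t backend = ggml_backend_rpc_init(i);
        if (backend == NULL) { // assumed failure convention
            fprintf(stderr, "failed to initialize RPC backend %d\n", i);
            return 1;
        }
        printf("RPC backend %d ready (is_rpc=%d)\n", i, ggml_backend_is_rpc(backend));
    }
    return 0;
}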
10 changes: 10 additions & 0 deletions llama.cpp
@@ -14957,6 +14957,16 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
}
}

void llama_rpc_init(const char * rpc_servers) {
#ifdef GGML_USE_RPC
ggml_rpc_init(rpc_servers);
#else
if (rpc_servers != nullptr) {
LLAMA_LOG_WARN("%s: RPC support is not enabled in this build\n", __func__);
}
#endif
}

void llama_backend_free(void) {
#ifdef GGML_USE_MPI
ggml_mpi_backend_free();
1 change: 1 addition & 0 deletions llama.h
@@ -358,6 +358,7 @@ extern "C" {

//optional:
LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
LLAMA_API void llama_rpc_init(const char * rpc_servers);

// Call once at the end of the program - currently only used for MPI
LLAMA_API void llama_backend_free(void);
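Illustration (not part of this commit): an application built with GGML_USE_RPC follows the same order as the examples/main change above: backend init, NUMA init, then llama_rpc_init() with either NULL or the comma-separated list. Without GGML_USE_RPC, a non-NULL list only triggers the warning added in llama.cpp. The NUMA strategy and the surrounding function are placeholders.

// Sketch: application-level initialization order with llama_rpc_init().
#include "llama.h"

// rpc_servers: e.g. "host1:50051,host2:50051", or NULL when RPC is not wanted
static void init_backends(const char * rpc_servers) {
    llama_backend_init();
    llama_numa_init(GGML_NUMA_STRATEGY_DISABLED); // NUMA disabled here for simplicity
    llama_rpc_init(rpc_servers);

    // ... load the model and create the context as usual ...
    // teardown is unchanged: llama_backend_free() at program exit
}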
