Commit

Merge pull request #1060 from iorisa/fixbug/issues/1059
fixbug: llm.timeout not working
geekan authored Mar 21, 2024
2 parents 0958cc3 + c6b9e23 commit 0e2a578
Showing 16 changed files with 99 additions and 70 deletions.
1 change: 1 addition & 0 deletions config/config2.example.yaml
@@ -4,6 +4,7 @@ llm:
   api_key: "YOUR_API_KEY"
   model: "gpt-4-turbo-preview"  # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
   proxy: "YOUR_PROXY"  # for LLM API requests
+  # timeout: 600 # Optional. If set to 0, default value is 300.
   pricing_plan: ""  # Optional. If invalid, it will be automatically filled in with the value of the `model`.
   # Azure-exclusive pricing plan mappings:
   # - gpt-3.5-turbo 4k: "gpt-3.5-turbo-1106"
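
In plain terms: leaving the key commented out keeps the new 600-second field default, while an explicit 0 is mapped back to the 300-second constant (see the LLMConfig validator and metagpt/const.py below). A minimal sketch of the resolution, with illustrative values:

    # Sketch of how the YAML `llm.timeout` value resolves after this commit.
    LLM_API_TIMEOUT = 300  # hard fallback (metagpt/const.py)
    FIELD_DEFAULT = 600    # LLMConfig.timeout default (metagpt/configs/llm_config.py)

    def resolved_timeout(yaml_value=None):
        if yaml_value is None:  # key omitted -> pydantic field default
            return FIELD_DEFAULT
        return yaml_value or LLM_API_TIMEOUT  # 0 -> constant, else as given

    assert resolved_timeout() == 600     # `# timeout:` left commented out
    assert resolved_timeout(0) == 300    # explicit 0 means "use the default"
    assert resolved_timeout(900) == 900  # explicit override wins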
9 changes: 6 additions & 3 deletions metagpt/actions/action_node.py
@@ -17,6 +17,7 @@
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 
 from metagpt.actions.action_outcls_registry import register_action_outcls
+from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.llm import BaseLLM
 from metagpt.logs import logger
 from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
@@ -416,7 +417,7 @@ async def _aask_v1(
         images: Optional[Union[str, list[str]]] = None,
         system_msgs: Optional[list[str]] = None,
         schema="markdown",  # compatible to original format
-        timeout=3,
+        timeout=USE_CONFIG_TIMEOUT,
     ) -> (str, BaseModel):
         """Use ActionOutput to wrap the output of aask"""
         content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
@@ -448,7 +449,9 @@ def set_llm(self, llm):
     def set_context(self, context):
         self.set_recursive("context", context)
 
-    async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None):
+    async def simple_fill(
+        self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=USE_CONFIG_TIMEOUT, exclude=None
+    ):
         prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
 
         if schema != "raw":
@@ -473,7 +476,7 @@ async def fill(
         mode="auto",
         strgy="simple",
         images: Optional[Union[str, list[str]]] = None,
-        timeout=3,
+        timeout=USE_CONFIG_TIMEOUT,
         exclude=[],
     ):
         """Fill the node(s) with mode.
8 changes: 7 additions & 1 deletion metagpt/configs/llm_config.py
@@ -10,6 +10,7 @@
 
 from pydantic import field_validator
 
+from metagpt.const import LLM_API_TIMEOUT
 from metagpt.utils.yaml_model import YamlModel
 
 
@@ -74,7 +75,7 @@ class LLMConfig(YamlModel):
     stream: bool = False
     logprobs: Optional[bool] = None  # https://cookbook.openai.com/examples/using_logprobs
     top_logprobs: Optional[int] = None
-    timeout: int = 60
+    timeout: int = 600
 
     # For Network
     proxy: Optional[str] = None
@@ -88,3 +89,8 @@ def check_llm_key(cls, v):
         if v in ["", None, "YOUR_API_KEY"]:
             raise ValueError("Please set your API key in config2.yaml")
         return v
+
+    @field_validator("timeout")
+    @classmethod
+    def check_timeout(cls, v):
+        return v or LLM_API_TIMEOUT
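
The new `check_timeout` validator is what makes `timeout: 0` safe: pydantic applies it to any supplied value, so a falsy 0 is normalized to `LLM_API_TIMEOUT` before any provider reads the field. A self-contained sketch of the same pattern (a stand-in model, not the real class):

    # Stand-in model demonstrating the validator pattern used above.
    from pydantic import BaseModel, field_validator

    LLM_API_TIMEOUT = 300

    class TimeoutConfig(BaseModel):
        timeout: int = 600

        @field_validator("timeout")
        @classmethod
        def check_timeout(cls, v: int) -> int:
            return v or LLM_API_TIMEOUT  # 0 is falsy -> fall back to 300

    assert TimeoutConfig().timeout == 600           # default; validator not run on defaults
    assert TimeoutConfig(timeout=0).timeout == 300  # 0 normalized by the validator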
5 changes: 4 additions & 1 deletion metagpt/const.py
@@ -123,7 +123,6 @@ def get_metagpt_root():
 
 # REDIS
 REDIS_KEY = "REDIS_KEY"
-LLM_API_TIMEOUT = 300
 
 # Message id
 IGNORED_MESSAGE_ID = "0"
@@ -132,3 +131,7 @@ def get_metagpt_root():
 GENERALIZATION = "Generalize"
 COMPOSITION = "Composite"
 AGGREGATION = "Aggregate"
+
+# Timeout
+USE_CONFIG_TIMEOUT = 0  # Using llm.timeout configuration.
+LLM_API_TIMEOUT = 300
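
`USE_CONFIG_TIMEOUT = 0` is a sentinel: call sites that used to hard-code `timeout=3` now pass 0, meaning "no per-call opinion, defer to `llm.timeout`". Keeping the sentinel falsy is what lets a plain `or`-chain resolve it, e.g.:

    # The falsy sentinel lets `or` skip it (illustrative config value).
    USE_CONFIG_TIMEOUT = 0
    LLM_API_TIMEOUT = 300
    config_timeout = 600  # from llm.timeout

    assert (USE_CONFIG_TIMEOUT or config_timeout or LLM_API_TIMEOUT) == 600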
9 changes: 5 additions & 4 deletions metagpt/provider/anthropic_api.py
@@ -5,6 +5,7 @@
 from anthropic.types import Message, Usage
 
 from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
@@ -41,15 +42,15 @@ def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool
     def get_choice_text(self, resp: Message) -> str:
         return resp.content[0].text
 
-    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
+    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
         resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
         self._update_costs(resp.usage, self.model)
         return resp
 
-    async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
-        return await self._achat_completion(messages, timeout=timeout)
+    async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
+        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
         collected_content = []
         usage = Usage(input_tokens=0, output_tokens=0)
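
This pattern repeats in every provider touched by the commit: the internal `_achat_completion*` methods default to the sentinel, and the public `acompletion` entry point resolves it once with `self.get_timeout(timeout)` (defined in BaseLLM below). A hypothetical provider following the convention:

    # Hypothetical provider sketch following this commit's convention
    # (other BaseLLM abstract methods omitted for brevity).
    from metagpt.const import USE_CONFIG_TIMEOUT
    from metagpt.provider.base_llm import BaseLLM

    class SketchProvider(BaseLLM):
        async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT):
            # `timeout` arrives here already resolved by the public caller.
            ...

        async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT):
            # Resolve the sentinel exactly once, at the public boundary.
            return await self._achat_completion(messages, timeout=self.get_timeout(timeout))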
28 changes: 17 additions & 11 deletions metagpt/provider/base_llm.py
@@ -23,6 +23,7 @@
 )
 
 from metagpt.configs.llm_config import LLMConfig
+from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
 from metagpt.logs import logger
 from metagpt.schema import Message
 from metagpt.utils.common import log_and_reraise
@@ -130,7 +131,7 @@ async def aask(
         system_msgs: Optional[list[str]] = None,
         format_msgs: Optional[list[dict[str, str]]] = None,
         images: Optional[Union[str, list[str]]] = None,
-        timeout=3,
+        timeout=USE_CONFIG_TIMEOUT,
         stream=True,
     ) -> str:
         if system_msgs:
@@ -146,31 +147,31 @@
         else:
             message.extend(msg)
         logger.debug(message)
-        rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
+        rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
         return rsp
 
     def _extract_assistant_rsp(self, context):
         return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
 
-    async def aask_batch(self, msgs: list, timeout=3) -> str:
+    async def aask_batch(self, msgs: list, timeout=USE_CONFIG_TIMEOUT) -> str:
         """Sequential questioning"""
         context = []
         for msg in msgs:
             umsg = self._user_msg(msg)
             context.append(umsg)
-            rsp_text = await self.acompletion_text(context, timeout=timeout)
+            rsp_text = await self.acompletion_text(context, timeout=self.get_timeout(timeout))
             context.append(self._assistant_msg(rsp_text))
         return self._extract_assistant_rsp(context)
 
-    async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3, **kwargs) -> dict:
+    async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict:
         raise NotImplementedError
 
     @abstractmethod
-    async def _achat_completion(self, messages: list[dict], timeout=3):
+    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         """_achat_completion implemented by inherited class"""
 
     @abstractmethod
-    async def acompletion(self, messages: list[dict], timeout=3):
+    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         """Asynchronous version of completion
         All GPTAPIs are required to provide the standard OpenAI completion interface
         [
@@ -181,7 +182,7 @@ async def acompletion(self, messages: list[dict], timeout=3):
         """
 
     @abstractmethod
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         """_achat_completion_stream implemented by inherited class"""
 
     @retry(
@@ -191,11 +192,13 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3)
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
+    async def acompletion_text(
+        self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT
+    ) -> str:
         """Asynchronous version of completion. Return str. Support stream-print"""
         if stream:
-            return await self._achat_completion_stream(messages, timeout=timeout)
-        resp = await self._achat_completion(messages, timeout=timeout)
+            return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout))
+        resp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
         return self.get_choice_text(resp)
 
     def get_choice_text(self, rsp: dict) -> str:
@@ -258,3 +261,6 @@ def with_model(self, model: str):
         """Set model and return self. For example, `with_model("gpt-3.5-turbo")`."""
         self.config.model = model
         return self
+
+    def get_timeout(self, timeout: int) -> int:
+        return timeout or self.config.timeout or LLM_API_TIMEOUT
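
`get_timeout` is the single decision point for the whole chain: an explicit per-call value wins, otherwise the configured `llm.timeout`, otherwise the 300-second constant. For example, with illustrative values:

    # The fallback chain implemented by BaseLLM.get_timeout (illustrative).
    def get_timeout(call_timeout: int, config_timeout: int, default: int = 300) -> int:
        return call_timeout or config_timeout or default

    assert get_timeout(0, 600) == 600  # sentinel defers to the config value
    assert get_timeout(30, 600) == 30  # explicit per-call timeout wins
    assert get_timeout(0, 0) == 300    # nothing set anywhere -> constant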
9 changes: 5 additions & 4 deletions metagpt/provider/dashscope_api.py
@@ -25,6 +25,7 @@
     UnsupportedApiProtocol,
 )
 
+from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM, LLMConfig
 from metagpt.provider.llm_provider_registry import LLMType, register_provider
@@ -202,16 +203,16 @@ def completion(self, messages: list[dict]) -> GenerationOutput:
         self._update_costs(dict(resp.usage))
         return resp.output
 
-    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
+    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> GenerationOutput:
         resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
         self._check_response(resp)
         self._update_costs(dict(resp.usage))
         return resp.output
 
-    async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
-        return await self._achat_completion(messages, timeout=timeout)
+    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> GenerationOutput:
+        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
         collected_content = []
         usage = {}
2 changes: 1 addition & 1 deletion metagpt/provider/general_api_base.py
@@ -573,7 +573,7 @@ async def arequest_raw(
                 total=request_timeout[1],
             )
         else:
-            timeout = aiohttp.ClientTimeout(total=request_timeout if request_timeout else TIMEOUT_SECS)
+            timeout = aiohttp.ClientTimeout(total=request_timeout or TIMEOUT_SECS)
 
         if files:
             # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
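
The change here is a pure simplification with no behavior change: `x if x else y` and `x or y` agree for every value of `request_timeout`, since `or` returns its first truthy operand. A quick check (`TIMEOUT_SECS` stands in for the module constant):

    # `a if a else b` is equivalent to `a or b` for any value of `a`.
    TIMEOUT_SECS = 600  # stand-in for the module-level constant
    for request_timeout in (None, 0, 0.0, 30, 600):
        assert (request_timeout if request_timeout else TIMEOUT_SECS) == (request_timeout or TIMEOUT_SECS)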
11 changes: 7 additions & 4 deletions metagpt/provider/google_gemini_api.py
@@ -16,6 +16,7 @@
 )
 
 from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream, logger
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
@@ -123,16 +124,18 @@ def completion(self, messages: list[dict]) -> "GenerateContentResponse":
         self._update_costs(usage)
         return resp
 
-    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
+    async def _achat_completion(
+        self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT
+    ) -> "AsyncGenerateContentResponse":
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
         usage = await self.aget_usage(messages, resp.text)
         self._update_costs(usage)
         return resp
 
-    async def acompletion(self, messages: list[dict], timeout=3) -> dict:
-        return await self._achat_completion(messages, timeout=timeout)
+    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
+        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
             **self._const_kwargs(messages, stream=True)
         )
15 changes: 8 additions & 7 deletions metagpt/provider/human_provider.py
@@ -6,6 +6,7 @@
 from typing import Optional
 
 from metagpt.configs.llm_config import LLMConfig
+from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import logger
 from metagpt.provider.base_llm import BaseLLM
 
@@ -18,7 +19,7 @@ class HumanProvider(BaseLLM):
     def __init__(self, config: LLMConfig):
         pass
 
-    def ask(self, msg: str, timeout=3) -> str:
+    def ask(self, msg: str, timeout=USE_CONFIG_TIMEOUT) -> str:
         logger.info("It's your turn, please type in your response. You may also refer to the context below")
         rsp = input(msg)
         if rsp in ["exit", "quit"]:
@@ -31,20 +32,20 @@ async def aask(
         system_msgs: Optional[list[str]] = None,
         format_msgs: Optional[list[dict[str, str]]] = None,
         generator: bool = False,
-        timeout=3,
+        timeout=USE_CONFIG_TIMEOUT,
     ) -> str:
-        return self.ask(msg, timeout=timeout)
+        return self.ask(msg, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion(self, messages: list[dict], timeout=3):
+    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         pass
 
-    async def acompletion(self, messages: list[dict], timeout=3):
+    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         """dummy implementation of abstract method in base"""
         return []
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         pass
 
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
         """dummy implementation of abstract method in base"""
         return ""
14 changes: 7 additions & 7 deletions metagpt/provider/ollama_api.py
@@ -5,7 +5,7 @@
 import json
 
 from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.const import LLM_API_TIMEOUT
+from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.general_api_requestor import GeneralAPIRequestor
@@ -50,28 +50,28 @@ def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
         chunk = chunk.decode(encoding)
         return json.loads(chunk)
 
-    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
+    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.suffix_url,
             params=self._const_kwargs(messages),
-            request_timeout=LLM_API_TIMEOUT,
+            request_timeout=self.get_timeout(timeout),
         )
         resp = self._decode_and_load(resp)
         usage = self.get_usage(resp)
         self._update_costs(usage)
         return resp
 
-    async def acompletion(self, messages: list[dict], timeout=3) -> dict:
-        return await self._achat_completion(messages, timeout=timeout)
+    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
+        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         stream_resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.suffix_url,
             stream=True,
             params=self._const_kwargs(messages, stream=True),
-            request_timeout=LLM_API_TIMEOUT,
+            request_timeout=self.get_timeout(timeout),
         )
 
         collected_content = []
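
This file shows the reported bug most directly: the Ollama requestor pinned `request_timeout` to the `LLM_API_TIMEOUT` constant, so a user's `llm.timeout` never reached the HTTP layer. With `get_timeout`, the configured value now propagates; with illustrative values:

    # Before vs. after for a user who set `timeout: 900` in config2.yaml (sketch).
    LLM_API_TIMEOUT = 300
    config_timeout = 900  # user's llm.timeout
    call_timeout = 0      # USE_CONFIG_TIMEOUT sentinel from the call site

    before = LLM_API_TIMEOUT                                    # always 300: config ignored
    after = call_timeout or config_timeout or LLM_API_TIMEOUT   # 900: config honored
    assert (before, after) == (300, 900)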