Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename outdated WaveAI types #1609

Merged
merged 4 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/app/store/services.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class BlockServiceType {
SaveTerminalState(blockId: string, state: string, stateType: string, ptyOffset: number, termSize: TermSize): Promise<void> {
return WOS.callBackendService("block", "SaveTerminalState", Array.from(arguments))
}
SaveWaveAiData(arg2: string, arg3: OpenAIPromptMessageType[]): Promise<void> {
SaveWaveAiData(arg2: string, arg3: WaveAIPromptMessageType[]): Promise<void> {
return WOS.callBackendService("block", "SaveWaveAiData", Array.from(arguments))
}
}
Expand Down
2 changes: 1 addition & 1 deletion frontend/app/store/wshclientapi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class RpcApiType {
}

// command "streamwaveai" [responsestream]
StreamWaveAiCommand(client: WshClient, data: OpenAiStreamRequest, opts?: RpcOpts): AsyncGenerator<OpenAIPacketType, void, boolean> {
StreamWaveAiCommand(client: WshClient, data: WaveAIStreamRequest, opts?: RpcOpts): AsyncGenerator<WaveAIPacketType, void, boolean> {
return client.wshRpcStream("streamwaveai", data, opts);
}

Expand Down
18 changes: 9 additions & 9 deletions frontend/app/view/waveai/waveai.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ interface ChatItemProps {
model: WaveAiModel;
}

function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
function promptToMsg(prompt: WaveAIPromptMessageType): ChatMessageType {
return {
id: crypto.randomUUID(),
user: prompt.role,
Expand Down Expand Up @@ -67,7 +67,7 @@ export class WaveAiModel implements ViewModel {
blockAtom: Atom<Block>;
presetKey: Atom<string>;
presetMap: Atom<{ [k: string]: MetaType }>;
aiOpts: Atom<OpenAIOptsType>;
aiOpts: Atom<WaveAIOptsType>;
viewIcon?: Atom<string | IconButtonDecl>;
viewName?: Atom<string>;
viewText?: Atom<string | HeaderElem[]>;
Expand Down Expand Up @@ -167,7 +167,7 @@ export class WaveAiModel implements ViewModel {
...settings,
...meta,
};
const opts: OpenAIOptsType = {
const opts: WaveAIOptsType = {
model: settings["ai:model"] ?? null,
apitype: settings["ai:apitype"] ?? null,
orgid: settings["ai:orgid"] ?? null,
Expand Down Expand Up @@ -293,12 +293,12 @@ export class WaveAiModel implements ViewModel {
globalStore.set(this.messagesAtom, history.map(promptToMsg));
}

async fetchAiData(): Promise<Array<OpenAIPromptMessageType>> {
async fetchAiData(): Promise<Array<WaveAIPromptMessageType>> {
const { data } = await fetchWaveFile(this.blockId, "aidata");
if (!data) {
return [];
}
const history: Array<OpenAIPromptMessageType> = JSON.parse(new TextDecoder().decode(data));
const history: Array<WaveAIPromptMessageType> = JSON.parse(new TextDecoder().decode(data));
return history.slice(Math.max(history.length - slidingWindowSize, 0));
}

Expand Down Expand Up @@ -333,7 +333,7 @@ export class WaveAiModel implements ViewModel {
globalStore.set(this.addMessageAtom, newMessage);
// send message to backend and get response
const opts = globalStore.get(this.aiOpts);
const newPrompt: OpenAIPromptMessageType = {
const newPrompt: WaveAIPromptMessageType = {
role: "user",
content: text,
};
Expand Down Expand Up @@ -368,7 +368,7 @@ export class WaveAiModel implements ViewModel {
// only save the author's prompt
await BlockService.SaveWaveAiData(this.blockId, [...history, newPrompt]);
} else {
const responsePrompt: OpenAIPromptMessageType = {
const responsePrompt: WaveAIPromptMessageType = {
role: "assistant",
content: fullMsg,
};
Expand All @@ -383,7 +383,7 @@ export class WaveAiModel implements ViewModel {
globalStore.set(this.removeLastMessageAtom);
} else {
globalStore.set(this.updateLastMessageAtom, "", false);
const responsePrompt: OpenAIPromptMessageType = {
const responsePrompt: WaveAIPromptMessageType = {
role: "assistant",
content: fullMsg,
};
Expand All @@ -397,7 +397,7 @@ export class WaveAiModel implements ViewModel {
};
globalStore.set(this.addMessageAtom, errorMessage);
globalStore.set(this.updateLastMessageAtom, "", false);
const errorPrompt: OpenAIPromptMessageType = {
const errorPrompt: WaveAIPromptMessageType = {
role: "error",
content: errMsg,
};
Expand Down
92 changes: 46 additions & 46 deletions frontend/types/gotypes.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -519,52 +519,6 @@ declare global {
// waveobj.ORef
type ORef = string;

// wshrpc.OpenAIOptsType
// NOTE(review): removed side of this diff — renamed to WaveAIOptsType in this PR.
type OpenAIOptsType = {
model: string; // model identifier requested from the AI backend
apitype?: string; // provider/API flavor selector — presumably distinguishes backends; confirm in Go source
apitoken: string; // credential sent with requests
orgid?: string;
apiversion?: string;
baseurl?: string; // endpoint URL override
maxtokens?: number;
maxchoices?: number;
timeoutms?: number; // request timeout in milliseconds (per the ms suffix)
};

// wshrpc.OpenAIPacketType
// NOTE(review): removed side of this diff — renamed to WaveAIPacketType in this PR.
// One packet of a streamed AI response (see StreamWaveAiCommand's AsyncGenerator).
type OpenAIPacketType = {
type: string; // packet discriminator
model?: string;
created?: number; // presumably a creation timestamp — confirm units in Go source
finish_reason?: string;
usage?: OpenAIUsageType; // token accounting, when the backend reports it
index?: number;
text?: string; // incremental completion text
error?: string; // error message when the stream fails
};

// wshrpc.OpenAIPromptMessageType
// NOTE(review): removed side of this diff — renamed to WaveAIPromptMessageType in this PR.
// A single chat message; role values seen in this PR: "user", "assistant", "error".
type OpenAIPromptMessageType = {
role: string;
content: string;
name?: string;
};

// wshrpc.OpenAIUsageType
// NOTE(review): removed side of this diff — renamed to WaveAIUsageType in this PR.
// Token-count accounting for a completed request.
type OpenAIUsageType = {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number; // presumably prompt + completion — confirm backend guarantees this
};

// wshrpc.OpenAiStreamRequest
// NOTE(review): removed side of this diff — renamed to WaveAIStreamRequest in this PR.
// Request payload for the "streamwaveai" RPC command.
type OpenAiStreamRequest = {
clientid?: string;
opts: OpenAIOptsType; // backend/model configuration
prompt: OpenAIPromptMessageType[]; // conversation history to send
};

// wshrpc.PathCommandData
type PathCommandData = {
pathtype: string;
Expand Down Expand Up @@ -1016,6 +970,52 @@ declare global {
fullconfig: FullConfigType;
};

// wshrpc.WaveAIOptsType
// Configuration for a Wave AI backend connection. This file appears
// auto-generated from Go types (note the wshrpc.* tag comments) — prefer
// editing the Go struct over this declaration.
type WaveAIOptsType = {
model: string; // model identifier requested from the AI backend
apitype?: string; // provider/API flavor selector — presumably distinguishes backends; confirm in Go source
apitoken: string; // credential sent with requests
orgid?: string;
apiversion?: string;
baseurl?: string; // endpoint URL override
maxtokens?: number;
maxchoices?: number;
timeoutms?: number; // request timeout in milliseconds (per the ms suffix)
};

// wshrpc.WaveAIPacketType
// One packet of a streamed AI response; yielded by StreamWaveAiCommand's
// AsyncGenerator. Fields other than `type` are populated depending on the
// packet kind.
type WaveAIPacketType = {
type: string; // packet discriminator
model?: string;
created?: number; // presumably a creation timestamp — confirm units in Go source
finish_reason?: string;
usage?: WaveAIUsageType; // token accounting, when the backend reports it
index?: number;
text?: string; // incremental completion text
error?: string; // error message when the stream fails
};

// wshrpc.WaveAIPromptMessageType
// A single chat message in the AI conversation history; role values seen
// in this PR: "user", "assistant", "error".
type WaveAIPromptMessageType = {
role: string;
content: string;
name?: string;
};

// wshrpc.WaveAIStreamRequest
// Request payload for the "streamwaveai" RPC command (see wshclientapi.ts).
type WaveAIStreamRequest = {
clientid?: string;
opts: WaveAIOptsType; // backend/model configuration
prompt: WaveAIPromptMessageType[]; // conversation history to send
};

// wshrpc.WaveAIUsageType
// Token-count accounting for a completed request.
type WaveAIUsageType = {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number; // presumably prompt + completion — confirm backend guarantees this
};

// wps.WaveEvent
type WaveEvent = {
event: string;
Expand Down
2 changes: 1 addition & 1 deletion pkg/service/blockservice/blockservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (bs *BlockService) SaveTerminalState(ctx context.Context, blockId string, s
return nil
}

func (bs *BlockService) SaveWaveAiData(ctx context.Context, blockId string, history []wshrpc.OpenAIPromptMessageType) error {
func (bs *BlockService) SaveWaveAiData(ctx context.Context, blockId string, history []wshrpc.WaveAIPromptMessageType) error {
block, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
if err != nil {
return err
Expand Down
28 changes: 14 additions & 14 deletions pkg/waveai/anthropicbackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ func parseSSE(reader *bufio.Reader) (*sseEvent, error) {
}
}

func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])

go func() {
defer func() {
Expand Down Expand Up @@ -231,23 +231,23 @@ func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.Ope
switch sse.Event {
case "message_start":
if event.Message != nil {
pk := MakeOpenAIPacket()
pk := MakeWaveAIPacket()
pk.Model = event.Message.Model
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
}

case "content_block_start":
if event.ContentBlock != nil && event.ContentBlock.Text != "" {
pk := MakeOpenAIPacket()
pk := MakeWaveAIPacket()
pk.Text = event.ContentBlock.Text
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
}

case "content_block_delta":
if event.Delta != nil && event.Delta.Text != "" {
pk := MakeOpenAIPacket()
pk := MakeWaveAIPacket()
pk.Text = event.Delta.Text
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
}

case "content_block_stop":
Expand All @@ -258,27 +258,27 @@ func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.Ope
case "message_delta":
// Update message metadata, usage stats
if event.Usage != nil {
pk := MakeOpenAIPacket()
pk.Usage = &wshrpc.OpenAIUsageType{
pk := MakeWaveAIPacket()
pk.Usage = &wshrpc.WaveAIUsageType{
PromptTokens: event.Usage.InputTokens,
CompletionTokens: event.Usage.OutputTokens,
TotalTokens: event.Usage.InputTokens + event.Usage.OutputTokens,
}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
}

case "message_stop":
if event.Message != nil {
pk := MakeOpenAIPacket()
pk := MakeWaveAIPacket()
pk.FinishReason = event.Message.StopReason
if event.Message.Usage != nil {
pk.Usage = &wshrpc.OpenAIUsageType{
pk.Usage = &wshrpc.WaveAIUsageType{
PromptTokens: event.Message.Usage.InputTokens,
CompletionTokens: event.Message.Usage.OutputTokens,
TotalTokens: event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens,
}
}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
}

default:
Expand Down
20 changes: 10 additions & 10 deletions pkg/waveai/cloudbackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,22 @@ type WaveAICloudBackend struct{}

var _ AIBackend = WaveAICloudBackend{}

type OpenAICloudReqPacketType struct {
type WaveAICloudReqPacketType struct {
Type string `json:"type"`
ClientId string `json:"clientid"`
Prompt []wshrpc.OpenAIPromptMessageType `json:"prompt"`
Prompt []wshrpc.WaveAIPromptMessageType `json:"prompt"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
}

func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
return &OpenAICloudReqPacketType{
func MakeWaveAICloudReqPacket() *WaveAICloudReqPacketType {
return &WaveAICloudReqPacketType{
Type: OpenAICloudReqStr,
}
}

func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
wsEndpoint := wcloud.GetWSEndpoint()
go func() {
defer func() {
Expand Down Expand Up @@ -69,14 +69,14 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
}
}()
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
var sendablePromptMsgs []wshrpc.WaveAIPromptMessageType
for _, promptMsg := range request.Prompt {
if promptMsg.Role == "error" {
continue
}
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
}
reqPk := MakeOpenAICloudReqPacket()
reqPk := MakeWaveAICloudReqPacket()
reqPk.ClientId = request.ClientId
reqPk.Prompt = sendablePromptMsgs
reqPk.MaxTokens = request.Opts.MaxTokens
Expand All @@ -101,7 +101,7 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
break
}
var streamResp *wshrpc.OpenAIPacketType
var streamResp *wshrpc.WaveAIPacketType
err = json.Unmarshal(socketMessage, &streamResp)
if err != nil {
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
Expand All @@ -115,7 +115,7 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O
rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
break
}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp}
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *streamResp}
}
}()
return rtn
Expand Down
Loading
Loading