diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b134612..89175d99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ * Bug修复:修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题 * 功能新增:支持上传图片和视觉模型 * 功能优化:优化聊天页面的复制代码按钮样式乱码 +* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图 +* 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用GPT3.5或者 GPT4 +* 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。 +* 功能新增:支持管理后台 Logo 修改 ## 4.0.2 diff --git a/api/core/app_server.go b/api/core/app_server.go index 69645add..5c9d2ad6 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -218,7 +218,7 @@ func needLogin(c *gin.Context) bool { c.Request.URL.Path == "/api/config/get" || c.Request.URL.Path == "/api/product/list" || c.Request.URL.Path == "/api/menu/list" || - c.Request.URL.Path == "/api/markMap/model" || + c.Request.URL.Path == "/api/markMap/client" || strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/function/") || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || diff --git a/api/core/types/web.go b/api/core/types/web.go index 601612fa..041a9859 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -21,7 +21,7 @@ const ( WsStart = WsMsgType("start") WsMiddle = WsMsgType("middle") WsEnd = WsMsgType("end") - WsMjImg = WsMsgType("mj") + WsErr = WsMsgType("error") ) type BizCode int diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 08785752..dbb9f682 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -525,7 +525,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi request = request.WithContext(ctx) request.Header.Set("Content-Type", "application/json") - var proxyURL string if len(apiKey.ProxyURL) > 5 { // 使用代理 proxy, _ := url.Parse(apiKey.ProxyURL) client = &http.Client{ @@ -536,7 +535,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi } else { client = http.DefaultClient } - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, proxyURL, req.Model) + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) switch session.Model.Platform { case types.Azure: request.Header.Set("api-key", apiKey.Value) diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index e4e57620..794eb284 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -1,26 +1,35 @@ package handler import ( + "bufio" + "bytes" "chatplus/core" "chatplus/core/types" + "chatplus/store/model" "chatplus/utils" - "github.com/gorilla/websocket" - "net/http" - + "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "gorm.io/gorm" + "io" + "net/http" + "net/url" + "strings" + "time" ) // MarkMapHandler 生成思维导图 type MarkMapHandler struct { BaseHandler - clients *types.LMap[uint, *types.WsClient] + clients *types.LMap[int, *types.WsClient] } func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { return &MarkMapHandler{ BaseHandler: BaseHandler{App: app, DB: db}, - clients: types.NewLMap[uint, *types.WsClient](), + clients: types.NewLMap[int, *types.WsClient](), } } @@ -32,9 +41,13 @@ func (h *MarkMapHandler) Client(c *gin.Context) { } modelId := h.GetInt(c, "model_id", 0) - userId := h.GetLoginUserId(c) + userId := h.GetInt(c, "user_id", 0) logger.Info(modelId) + client := types.NewWsClient(ws) + if cli := h.clients.Get(userId); cli != nil { + cli.Close() + } // 保存会话连接 h.clients.Put(userId, client) @@ -55,12 +68,165 @@ func (h *MarkMapHandler) Client(c *gin.Context) { // 心跳消息 if message.Type == "heartbeat" { - logger.Debug("收到 Chat 心跳消息:", message.Content) + logger.Debug("收到 MarkMap 心跳消息:", message.Content) + continue + } + // change model + if message.Type == "model_id" { + modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0) continue } logger.Info("Receive a message: ", message.Content) + err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId) + if err != nil { + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()}) + } } }() } + +func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error { + var user model.User + res := h.DB.Model(&model.User{}).First(&user, userId) + if res.Error != nil { + return fmt.Errorf("error with query user info: %v", res.Error) + } + var chatModel model.ChatModel + res = h.DB.Where("id", modelId).First(&chatModel) + if res.Error != nil { + return fmt.Errorf("error with query chat model: %v", res.Error) + } + + if user.Status == false { + return errors.New("当前用户被禁用") + } + + if user.Power < chatModel.Power { + return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power) + } + + messages := make([]interface{}, 0) + messages = append(messages, types.Message{Role: "system", Content: "你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。只输出 Markdown 内容,不要输出任何解释性的语句。"}) + messages = append(messages, types.Message{Role: "user", Content: prompt}) + var req = types.ApiRequest{ + Model: chatModel.Value, + Stream: true, + Messages: messages, + } + + var apiKey model.ApiKey + response, err := h.doRequest(req, chatModel, &apiKey) + if err != nil { + return fmt.Errorf("请求 OpenAI API 失败: %s", err) + } + + defer response.Body.Close() + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + // 循环读取 Chunk 消息 + var message = types.Message{} + scanner := bufio.NewScanner(response.Body) + var isNew = true + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, "data:") || len(line) < 30 { + continue + } + + var responseBody = types.ApiResponse{} + err = json.Unmarshal([]byte(line[6:]), &responseBody) + if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 + return fmt.Errorf("error with decode data: %v", err) + } + + // 初始化 role + if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { + message.Role = responseBody.Choices[0].Delta.Role + continue + } else if responseBody.Choices[0].FinishReason != "" { + break // 输出完成或者输出中断了 + } else { + if isNew { + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart}) + isNew = false + } + utils.ReplyChunkMessage(client, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), + }) + } + } // end for + + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) + + } else { + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("读取响应失败: %v", err) + } + var res types.ApiError + err = json.Unmarshal(body, &res) + if err != nil { + return fmt.Errorf("解析响应失败: %v", err) + } + + // OpenAI API 调用异常处理 + if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { + // remove key + h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{}) + return errors.New("请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") + } else if strings.Contains(res.Error.Message, "You exceeded your current quota") { + return errors.New("请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") + } else { + return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message) + } + } + + return nil +} + +func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) { + // if the chat model bind a KEY, use it directly + var res *gorm.DB + if chatModel.KeyId > 0 { + res = h.DB.Where("id", chatModel.KeyId).Find(apiKey) + } + // use the last unused key + if res.Error != nil { + res = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) + } + if res.Error != nil { + return nil, errors.New("no available key, please import key") + } + apiURL := apiKey.ApiURL + // 更新 API KEY 的最后使用时间 + h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + + // 创建 HttpClient 请求对象 + var client *http.Client + requestBody, err := json.Marshal(req) + if err != nil { + return nil, err + } + request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/json") + if len(apiKey.ProxyURL) > 5 { // 使用代理 + proxy, _ := url.Parse(apiKey.ProxyURL) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + } else { + client = http.DefaultClient + } + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) + return client.Do(request) +} diff --git a/api/main.go b/api/main.go index 4cc2d5c5..586b7f70 100644 --- a/api/main.go +++ b/api/main.go @@ -439,7 +439,7 @@ func main() { fx.Provide(handler.NewMarkMapHandler), fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { group := s.Engine.Group("/api/markMap/") - group.GET("model", h.GetModel) + group.Any("client", h.Client) }), fx.Invoke(func(s *core.AppServer, db *gorm.DB) { go func() { diff --git a/api/service/mj/plus_client.go b/api/service/mj/plus_client.go index 52846208..822d4b91 100644 --- a/api/service/mj/plus_client.go +++ b/api/service/mj/plus_client.go @@ -73,6 +73,7 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { // Blend 融图 func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) body := ImageReq{ BotType: "MID_JOURNEY", Dimensions: "SQUARE", @@ -163,7 +164,8 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), "taskId": task.MessageId, } - apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.Config.Mode, c.apiURL) + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) var res ImageRes var errRes ErrRes r, err := c.client.R(). @@ -189,7 +191,8 @@ func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), "taskId": task.MessageId, } - apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.Config.Mode, c.apiURL) + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) var res ImageRes var errRes ErrRes r, err := req.C().R(). diff --git a/database/update-v4.0.3.sql b/database/update-v4.0.3.sql index 219c4187..80250e73 100644 --- a/database/update-v4.0.3.sql +++ b/database/update-v4.0.3.sql @@ -1,2 +1,3 @@ ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`; -ALTER TABLE `chatgpt_chat_models` ADD `key_id` INT(11) NOT NULL COMMENT '绑定API KEY ID' AFTER `open`; \ No newline at end of file +ALTER TABLE `chatgpt_chat_models` ADD `key_id` INT(11) NOT NULL COMMENT '绑定API KEY ID' AFTER `open`; +INSERT INTO `chatgpt_plus`.`chatgpt_menus`(`id`, `name`, `icon`, `url`, `sort_num`, `enabled`) VALUES (12, '思维导图', '/images/menu/xmind.png', '/xmind', 3, 1); \ No newline at end of file diff --git a/web/src/assets/css/mark-map.styl b/web/src/assets/css/mark-map.styl index 943e33ae..d6d11f2a 100644 --- a/web/src/assets/css/mark-map.styl +++ b/web/src/assets/css/mark-map.styl @@ -66,10 +66,43 @@ .right-box { width 100% + h2 { color #ffffff } + .markdown { + color #ffffff + display flex + justify-content center + align-items center + + h1 { + color: #47fff1; + } + + h2 { + color: #ffcc00; + } + + ul { + list-style-type: disc; + margin-left: 20px; + + li { + line-height 1.5 + } + } + + strong { + font-weight: bold; + } + + em { + font-style: italic; + } + } + .body { display flex justify-content center diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 70bb4839..12952bf7 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -525,42 +525,10 @@