diff --git a/service/internal/executor/executor.go b/service/internal/executor/executor.go index 33de003..62a8999 100644 --- a/service/internal/executor/executor.go +++ b/service/internal/executor/executor.go @@ -132,9 +132,9 @@ func DefaultExecutor(cfg *config.Config) *Executor { } type listener interface { - OnExecutionStarted(logEntry *InternalLogEntry) - OnExecutionFinished(logEntry *InternalLogEntry) - OnOutputChunk(o []byte, executionTrackingId string) + OnExecutionStarted(logEntry *InternalLogEntry, action *config.Action) + OnExecutionFinished(logEntry *InternalLogEntry, action *config.Action) + OnOutputChunk(o []byte, executionTrackingId string, logEntry *InternalLogEntry, action *config.Action) OnActionMapRebuilt() } @@ -509,13 +509,13 @@ func stepLogFinish(req *ExecutionRequest) bool { func notifyListenersFinished(req *ExecutionRequest) { for _, listener := range req.executor.listeners { - listener.OnExecutionFinished(req.logEntry) + listener.OnExecutionFinished(req.logEntry, req.Action) } } func notifyListenersStarted(req *ExecutionRequest) { for _, listener := range req.executor.listeners { - listener.OnExecutionStarted(req.logEntry) + listener.OnExecutionStarted(req.logEntry, req.Action) } } @@ -532,7 +532,7 @@ type OutputStreamer struct { func (ost *OutputStreamer) Write(o []byte) (n int, err error) { for _, listener := range ost.Req.executor.listeners { - listener.OnOutputChunk(o, ost.Req.TrackingID) + listener.OnOutputChunk(o, ost.Req.TrackingID, ost.Req.logEntry, ost.Req.Action) } return ost.output.Write(o) diff --git a/service/internal/httpservers/websocket.go b/service/internal/httpservers/websocket.go index 5fffe45..415bf12 100644 --- a/service/internal/httpservers/websocket.go +++ b/service/internal/httpservers/websocket.go @@ -6,6 +6,7 @@ import ( apiv1 "github.com/OliveTin/OliveTin/gen/grpc/olivetin/api/v1" "github.com/OliveTin/OliveTin/internal/acl" + "github.com/OliveTin/OliveTin/internal/config" "github.com/OliveTin/OliveTin/internal/executor" ws "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" @@ -37,10 +38,16 @@ var ExecutionListener WebsocketExecutionListener type WebsocketExecutionListener struct{} -func (WebsocketExecutionListener) OnExecutionStarted(ile *executor.InternalLogEntry) { - broadcast(&apiv1.EventExecutionStarted{ +func (WebsocketExecutionListener) OnExecutionStarted(ile *executor.InternalLogEntry, action *config.Action) { + evt := &apiv1.EventExecutionStarted{ LogEntry: internalLogEntryToPb(ile), - }) + } + + for client := range copyOfClients() { + if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) { + writeMessageToClient(client, prepareMessage(evt)) + } + } } func OnEntityChanged() { @@ -68,7 +75,7 @@ func checkOriginPermissive(r *http.Request) bool { return true } -func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingId string) { +func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingId string, logEntry *executor.InternalLogEntry, action *config.Action) { log.Tracef("outputchunk: %s", string(chunk)) oc := &apiv1.EventOutputChunk{ @@ -76,25 +83,49 @@ func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingI ExecutionTrackingId: executionTrackingId, } - broadcast(oc) + for client := range copyOfClients() { + if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) { + writeMessageToClient(client, prepareMessage(oc)) + } + } } -func (WebsocketExecutionListener) OnExecutionFinished(logEntry *executor.InternalLogEntry) { +func (WebsocketExecutionListener) OnExecutionFinished(logEntry *executor.InternalLogEntry, action *config.Action) { evt := &apiv1.EventExecutionFinished{ LogEntry: internalLogEntryToPb(logEntry), } - log.Infof("WS Execution finished: %v", evt.LogEntry) + for client := range copyOfClients() { + if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) { + writeMessageToClient(client, prepareMessage(evt)) + } + } - broadcast(evt) + log.Infof("WS Execution finished: %v", evt.LogEntry) } -func broadcast(pbmsg protoreflect.ProtoMessage) { +func copyOfClients() map[*WebsocketClient]struct{} { + sendmutex.Lock() + defer sendmutex.Unlock() + + if clients == nil { + clients = make(map[*WebsocketClient]struct{}) + } + + copy := make(map[*WebsocketClient]struct{}) + for client := range clients { + copy[client] = struct{}{} + } + + return copy +} + +func prepareMessage(pbmsg protoreflect.ProtoMessage) []byte { payload, err := marshalOptions.Marshal(pbmsg) if err != nil { log.Errorf("websocket marshal error: %v", err) - return + return nil } messageType := pbmsg.ProtoReflect().Descriptor().FullName() @@ -118,16 +149,26 @@ func broadcast(pbmsg protoreflect.ProtoMessage) { hackyMessage = append(hackyMessage, []byte("}")...) // - sendmutex.Lock() - if clients == nil { - clients = make(map[*WebsocketClient]struct{}) + return hackyMessage +} + +func broadcast(pbmsg protoreflect.ProtoMessage) { + message := prepareMessage(pbmsg) + + for client := range copyOfClients() { + writeMessageToClient(client, message) } - for client := range clients { - if err := client.conn.WriteMessage(ws.TextMessage, hackyMessage); err != nil { - log.Debugf("websocket send error: %v - closing connection", err) - _ = client.conn.Close() - delete(clients, client) - } +} + +func writeMessageToClient(client *WebsocketClient, message []byte) { + sendmutex.Lock() + if err := client.conn.WriteMessage(ws.TextMessage, message); err != nil { + log.WithFields(log.Fields{ + "error": err, + "client": client, + }).Debugf("websocket send error") + _ = client.conn.Close() + delete(clients, client) } sendmutex.Unlock() } @@ -148,16 +189,14 @@ func (c *WebsocketClient) messageLoop() { func handleWebsocket(w http.ResponseWriter, r *http.Request) bool { c, err := upgrader.Upgrade(w, r, nil) - unauthenticatedUser := authHttpRequest(r) - - authenticatedUser := acl.UserFromUnauthenticatedUser(cfg, unauthenticatedUser) - if err != nil { log.Warnf("Websocket issue: %v", err) return false } - // defer c.Close() + unauthenticatedUser := authHttpRequest(r) + + authenticatedUser := acl.UserFromUnauthenticatedUser(cfg, unauthenticatedUser) wsclient := &WebsocketClient{ conn: c,