fix: #471 - action activity it not broadcasted to all clients

This commit is contained in:
jamesread
2025-11-05 09:08:04 +00:00
parent 2735793a44
commit 4e7273b616
2 changed files with 69 additions and 30 deletions

View File

@@ -132,9 +132,9 @@ func DefaultExecutor(cfg *config.Config) *Executor {
} }
type listener interface { type listener interface {
OnExecutionStarted(logEntry *InternalLogEntry) OnExecutionStarted(logEntry *InternalLogEntry, action *config.Action)
OnExecutionFinished(logEntry *InternalLogEntry) OnExecutionFinished(logEntry *InternalLogEntry, action *config.Action)
OnOutputChunk(o []byte, executionTrackingId string) OnOutputChunk(o []byte, executionTrackingId string, logEntry *InternalLogEntry, action *config.Action)
OnActionMapRebuilt() OnActionMapRebuilt()
} }
@@ -509,13 +509,13 @@ func stepLogFinish(req *ExecutionRequest) bool {
func notifyListenersFinished(req *ExecutionRequest) { func notifyListenersFinished(req *ExecutionRequest) {
for _, listener := range req.executor.listeners { for _, listener := range req.executor.listeners {
listener.OnExecutionFinished(req.logEntry) listener.OnExecutionFinished(req.logEntry, req.Action)
} }
} }
func notifyListenersStarted(req *ExecutionRequest) { func notifyListenersStarted(req *ExecutionRequest) {
for _, listener := range req.executor.listeners { for _, listener := range req.executor.listeners {
listener.OnExecutionStarted(req.logEntry) listener.OnExecutionStarted(req.logEntry, req.Action)
} }
} }
@@ -532,7 +532,7 @@ type OutputStreamer struct {
func (ost *OutputStreamer) Write(o []byte) (n int, err error) { func (ost *OutputStreamer) Write(o []byte) (n int, err error) {
for _, listener := range ost.Req.executor.listeners { for _, listener := range ost.Req.executor.listeners {
listener.OnOutputChunk(o, ost.Req.TrackingID) listener.OnOutputChunk(o, ost.Req.TrackingID, ost.Req.logEntry, ost.Req.Action)
} }
return ost.output.Write(o) return ost.output.Write(o)

View File

@@ -6,6 +6,7 @@ import (
apiv1 "github.com/OliveTin/OliveTin/gen/grpc/olivetin/api/v1" apiv1 "github.com/OliveTin/OliveTin/gen/grpc/olivetin/api/v1"
"github.com/OliveTin/OliveTin/internal/acl" "github.com/OliveTin/OliveTin/internal/acl"
"github.com/OliveTin/OliveTin/internal/config"
"github.com/OliveTin/OliveTin/internal/executor" "github.com/OliveTin/OliveTin/internal/executor"
ws "github.com/gorilla/websocket" ws "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -37,10 +38,16 @@ var ExecutionListener WebsocketExecutionListener
type WebsocketExecutionListener struct{} type WebsocketExecutionListener struct{}
func (WebsocketExecutionListener) OnExecutionStarted(ile *executor.InternalLogEntry) { func (WebsocketExecutionListener) OnExecutionStarted(ile *executor.InternalLogEntry, action *config.Action) {
broadcast(&apiv1.EventExecutionStarted{ evt := &apiv1.EventExecutionStarted{
LogEntry: internalLogEntryToPb(ile), LogEntry: internalLogEntryToPb(ile),
}) }
for client := range copyOfClients() {
if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) {
writeMessageToClient(client, prepareMessage(evt))
}
}
} }
func OnEntityChanged() { func OnEntityChanged() {
@@ -68,7 +75,7 @@ func checkOriginPermissive(r *http.Request) bool {
return true return true
} }
func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingId string) { func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingId string, logEntry *executor.InternalLogEntry, action *config.Action) {
log.Tracef("outputchunk: %s", string(chunk)) log.Tracef("outputchunk: %s", string(chunk))
oc := &apiv1.EventOutputChunk{ oc := &apiv1.EventOutputChunk{
@@ -76,25 +83,49 @@ func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingI
ExecutionTrackingId: executionTrackingId, ExecutionTrackingId: executionTrackingId,
} }
broadcast(oc) for client := range copyOfClients() {
if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) {
writeMessageToClient(client, prepareMessage(oc))
}
}
} }
func (WebsocketExecutionListener) OnExecutionFinished(logEntry *executor.InternalLogEntry) { func (WebsocketExecutionListener) OnExecutionFinished(logEntry *executor.InternalLogEntry, action *config.Action) {
evt := &apiv1.EventExecutionFinished{ evt := &apiv1.EventExecutionFinished{
LogEntry: internalLogEntryToPb(logEntry), LogEntry: internalLogEntryToPb(logEntry),
} }
log.Infof("WS Execution finished: %v", evt.LogEntry) for client := range copyOfClients() {
if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) {
writeMessageToClient(client, prepareMessage(evt))
}
}
broadcast(evt) log.Infof("WS Execution finished: %v", evt.LogEntry)
} }
func broadcast(pbmsg protoreflect.ProtoMessage) { func copyOfClients() map[*WebsocketClient]struct{} {
sendmutex.Lock()
defer sendmutex.Unlock()
if clients == nil {
clients = make(map[*WebsocketClient]struct{})
}
copy := make(map[*WebsocketClient]struct{})
for client := range clients {
copy[client] = struct{}{}
}
return copy
}
func prepareMessage(pbmsg protoreflect.ProtoMessage) []byte {
payload, err := marshalOptions.Marshal(pbmsg) payload, err := marshalOptions.Marshal(pbmsg)
if err != nil { if err != nil {
log.Errorf("websocket marshal error: %v", err) log.Errorf("websocket marshal error: %v", err)
return return nil
} }
messageType := pbmsg.ProtoReflect().Descriptor().FullName() messageType := pbmsg.ProtoReflect().Descriptor().FullName()
@@ -118,16 +149,26 @@ func broadcast(pbmsg protoreflect.ProtoMessage) {
hackyMessage = append(hackyMessage, []byte("}")...) hackyMessage = append(hackyMessage, []byte("}")...)
// </EVIL> // </EVIL>
sendmutex.Lock() return hackyMessage
if clients == nil { }
clients = make(map[*WebsocketClient]struct{})
func broadcast(pbmsg protoreflect.ProtoMessage) {
message := prepareMessage(pbmsg)
for client := range copyOfClients() {
writeMessageToClient(client, message)
} }
for client := range clients { }
if err := client.conn.WriteMessage(ws.TextMessage, hackyMessage); err != nil {
log.Debugf("websocket send error: %v - closing connection", err) func writeMessageToClient(client *WebsocketClient, message []byte) {
_ = client.conn.Close() sendmutex.Lock()
delete(clients, client) if err := client.conn.WriteMessage(ws.TextMessage, message); err != nil {
} log.WithFields(log.Fields{
"error": err,
"client": client,
}).Debugf("websocket send error")
_ = client.conn.Close()
delete(clients, client)
} }
sendmutex.Unlock() sendmutex.Unlock()
} }
@@ -148,16 +189,14 @@ func (c *WebsocketClient) messageLoop() {
func handleWebsocket(w http.ResponseWriter, r *http.Request) bool { func handleWebsocket(w http.ResponseWriter, r *http.Request) bool {
c, err := upgrader.Upgrade(w, r, nil) c, err := upgrader.Upgrade(w, r, nil)
unauthenticatedUser := authHttpRequest(r)
authenticatedUser := acl.UserFromUnauthenticatedUser(cfg, unauthenticatedUser)
if err != nil { if err != nil {
log.Warnf("Websocket issue: %v", err) log.Warnf("Websocket issue: %v", err)
return false return false
} }
// defer c.Close() unauthenticatedUser := authHttpRequest(r)
authenticatedUser := acl.UserFromUnauthenticatedUser(cfg, unauthenticatedUser)
wsclient := &WebsocketClient{ wsclient := &WebsocketClient{
conn: c, conn: c,