fix: #471 - action activity it not broadcasted to all clients

This commit is contained in:
jamesread
2025-11-05 09:08:04 +00:00
parent 2735793a44
commit 4e7273b616
2 changed files with 69 additions and 30 deletions

View File

@@ -132,9 +132,9 @@ func DefaultExecutor(cfg *config.Config) *Executor {
}
type listener interface {
OnExecutionStarted(logEntry *InternalLogEntry)
OnExecutionFinished(logEntry *InternalLogEntry)
OnOutputChunk(o []byte, executionTrackingId string)
OnExecutionStarted(logEntry *InternalLogEntry, action *config.Action)
OnExecutionFinished(logEntry *InternalLogEntry, action *config.Action)
OnOutputChunk(o []byte, executionTrackingId string, logEntry *InternalLogEntry, action *config.Action)
OnActionMapRebuilt()
}
@@ -509,13 +509,13 @@ func stepLogFinish(req *ExecutionRequest) bool {
func notifyListenersFinished(req *ExecutionRequest) {
for _, listener := range req.executor.listeners {
listener.OnExecutionFinished(req.logEntry)
listener.OnExecutionFinished(req.logEntry, req.Action)
}
}
func notifyListenersStarted(req *ExecutionRequest) {
for _, listener := range req.executor.listeners {
listener.OnExecutionStarted(req.logEntry)
listener.OnExecutionStarted(req.logEntry, req.Action)
}
}
@@ -532,7 +532,7 @@ type OutputStreamer struct {
func (ost *OutputStreamer) Write(o []byte) (n int, err error) {
for _, listener := range ost.Req.executor.listeners {
listener.OnOutputChunk(o, ost.Req.TrackingID)
listener.OnOutputChunk(o, ost.Req.TrackingID, ost.Req.logEntry, ost.Req.Action)
}
return ost.output.Write(o)

View File

@@ -6,6 +6,7 @@ import (
apiv1 "github.com/OliveTin/OliveTin/gen/grpc/olivetin/api/v1"
"github.com/OliveTin/OliveTin/internal/acl"
"github.com/OliveTin/OliveTin/internal/config"
"github.com/OliveTin/OliveTin/internal/executor"
ws "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
@@ -37,10 +38,16 @@ var ExecutionListener WebsocketExecutionListener
type WebsocketExecutionListener struct{}
func (WebsocketExecutionListener) OnExecutionStarted(ile *executor.InternalLogEntry) {
broadcast(&apiv1.EventExecutionStarted{
func (WebsocketExecutionListener) OnExecutionStarted(ile *executor.InternalLogEntry, action *config.Action) {
evt := &apiv1.EventExecutionStarted{
LogEntry: internalLogEntryToPb(ile),
})
}
for client := range copyOfClients() {
if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) {
writeMessageToClient(client, prepareMessage(evt))
}
}
}
func OnEntityChanged() {
@@ -68,7 +75,7 @@ func checkOriginPermissive(r *http.Request) bool {
return true
}
func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingId string) {
func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingId string, logEntry *executor.InternalLogEntry, action *config.Action) {
log.Tracef("outputchunk: %s", string(chunk))
oc := &apiv1.EventOutputChunk{
@@ -76,25 +83,49 @@ func (WebsocketExecutionListener) OnOutputChunk(chunk []byte, executionTrackingI
ExecutionTrackingId: executionTrackingId,
}
broadcast(oc)
for client := range copyOfClients() {
if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) {
writeMessageToClient(client, prepareMessage(oc))
}
}
}
func (WebsocketExecutionListener) OnExecutionFinished(logEntry *executor.InternalLogEntry) {
func (WebsocketExecutionListener) OnExecutionFinished(logEntry *executor.InternalLogEntry, action *config.Action) {
evt := &apiv1.EventExecutionFinished{
LogEntry: internalLogEntryToPb(logEntry),
}
log.Infof("WS Execution finished: %v", evt.LogEntry)
for client := range copyOfClients() {
if acl.IsAllowedLogs(cfg, client.authenticatedUser, action) {
writeMessageToClient(client, prepareMessage(evt))
}
}
broadcast(evt)
log.Infof("WS Execution finished: %v", evt.LogEntry)
}
func broadcast(pbmsg protoreflect.ProtoMessage) {
func copyOfClients() map[*WebsocketClient]struct{} {
sendmutex.Lock()
defer sendmutex.Unlock()
if clients == nil {
clients = make(map[*WebsocketClient]struct{})
}
copy := make(map[*WebsocketClient]struct{})
for client := range clients {
copy[client] = struct{}{}
}
return copy
}
func prepareMessage(pbmsg protoreflect.ProtoMessage) []byte {
payload, err := marshalOptions.Marshal(pbmsg)
if err != nil {
log.Errorf("websocket marshal error: %v", err)
return
return nil
}
messageType := pbmsg.ProtoReflect().Descriptor().FullName()
@@ -118,16 +149,26 @@ func broadcast(pbmsg protoreflect.ProtoMessage) {
hackyMessage = append(hackyMessage, []byte("}")...)
// </EVIL>
sendmutex.Lock()
if clients == nil {
clients = make(map[*WebsocketClient]struct{})
return hackyMessage
}
func broadcast(pbmsg protoreflect.ProtoMessage) {
message := prepareMessage(pbmsg)
for client := range copyOfClients() {
writeMessageToClient(client, message)
}
for client := range clients {
if err := client.conn.WriteMessage(ws.TextMessage, hackyMessage); err != nil {
log.Debugf("websocket send error: %v - closing connection", err)
_ = client.conn.Close()
delete(clients, client)
}
}
func writeMessageToClient(client *WebsocketClient, message []byte) {
sendmutex.Lock()
if err := client.conn.WriteMessage(ws.TextMessage, message); err != nil {
log.WithFields(log.Fields{
"error": err,
"client": client,
}).Debugf("websocket send error")
_ = client.conn.Close()
delete(clients, client)
}
sendmutex.Unlock()
}
@@ -148,16 +189,14 @@ func (c *WebsocketClient) messageLoop() {
func handleWebsocket(w http.ResponseWriter, r *http.Request) bool {
c, err := upgrader.Upgrade(w, r, nil)
unauthenticatedUser := authHttpRequest(r)
authenticatedUser := acl.UserFromUnauthenticatedUser(cfg, unauthenticatedUser)
if err != nil {
log.Warnf("Websocket issue: %v", err)
return false
}
// defer c.Close()
unauthenticatedUser := authHttpRequest(r)
authenticatedUser := acl.UserFromUnauthenticatedUser(cfg, unauthenticatedUser)
wsclient := &WebsocketClient{
conn: c,