// Package zwavejs implements a websocket client for a zwave-js server and
// mirrors lock state and lock events into the lockserver database.
package zwavejs

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/failsafe-go/failsafe-go"
	"github.com/failsafe-go/failsafe-go/retrypolicy"
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/sirupsen/logrus"

	"git.janky.solutions/finn/lockserver/db"
)

var (
	// errNotConnected is returned when an operation needs a live connection
	// but DialAndListen has not established one (or it has been torn down).
	errNotConnected = errors.New("not connected to zwave-js server")

	// errConnectionLost is returned to callers that were waiting for a reply
	// when the connection they sent their request on went away.
	errConnectionLost = errors.New("connection to zwave-js server lost before a response arrived")
)

// Client is a websocket client for a single zwave-js server.
//
// Create one with New, run DialAndListen in its own goroutine, and use
// SetNodeValue from other goroutines to send commands.
type Client struct {
	// Server is the URL handed to the websocket dialer.
	Server string

	// conn is the current connection, or nil while disconnected.
	// connLock guards the pointer because it is written by DialAndListen
	// and read by senders and Shutdown on other goroutines.
	conn     *websocket.Conn
	connLock sync.Mutex

	// writeLock serializes writes: the listen loop and sendMessage callers
	// both write, and gorilla/websocket supports only one concurrent writer.
	writeLock sync.Mutex

	// callbacks maps the ID of an in-flight request to the channel its
	// sender is waiting on. Guarded by callbacksLock.
	callbacks     map[string]chan Result
	callbacksLock sync.Mutex
}

// New returns a Client for the given server URL. It does not connect; call
// DialAndListen for that. The returned error is currently always nil.
func New(server string) (*Client, error) {
	c := &Client{
		Server:    server,
		callbacks: make(map[string]chan Result),
	}
	return c, nil
}

// setConn records conn as the active connection (nil means disconnected).
func (c *Client) setConn(conn *websocket.Conn) {
	c.connLock.Lock()
	c.conn = conn
	c.connLock.Unlock()
}

// currentConn returns the active connection, or nil while disconnected.
func (c *Client) currentConn() *websocket.Conn {
	c.connLock.Lock()
	defer c.connLock.Unlock()
	return c.conn
}

// writeJSON sends v as a JSON message on the active connection, holding
// writeLock so concurrent callers never interleave writes.
func (c *Client) writeJSON(v any) error {
	conn := c.currentConn()
	if conn == nil {
		return errNotConnected
	}

	c.writeLock.Lock()
	defer c.writeLock.Unlock()
	return conn.WriteJSON(v)
}

// failPendingCallbacks closes and unregisters every waiting callback channel
// so that senders blocked on a reply are released when the connection dies.
func (c *Client) failPendingCallbacks() {
	c.callbacksLock.Lock()
	defer c.callbacksLock.Unlock()

	for id, cb := range c.callbacks {
		close(cb)
		delete(c.callbacks, id)
	}
}

// DialAndListen connects to the server and processes incoming messages,
// reconnecting after communication errors. It blocks until the connection is
// closed (see Shutdown) or ctx is cancelled while it is dialing or waiting to
// reconnect. Cancelling ctx does not interrupt a read that is already
// blocked; call Shutdown for that.
//
// NOTE: if dialing still fails after the retry policy gives up and ctx has
// not been cancelled, this logs at Fatal level, which terminates the process
// under logrus' default exit handler.
func (c *Client) DialAndListen(ctx context.Context) {
	const reconnectDelay = 5 * time.Second

	// Failed dials are retried with a backoff that starts at one second and
	// is capped at one minute. No retry limit is set here, so the policy's
	// default applies.
	connectRetryPolicy := retrypolicy.Builder[*websocket.Conn]().
		WithBackoff(time.Second, time.Minute).
		Build()

	for {
		logrus.WithField("server", c.Server).Info("connecting to zwave-js server")

		conn, err := failsafe.Get(func() (*websocket.Conn, error) {
			conn, _, err := websocket.DefaultDialer.DialContext(ctx, c.Server, nil)
			return conn, err
		}, connectRetryPolicy)
		if err != nil {
			if ctx.Err() != nil {
				// The caller asked us to stop; that is not a fatal condition.
				return
			}
			logrus.WithError(err).Fatal("error connecting to zwavejs server")
			return // only reached if logrus' exit handler was overridden
		}

		c.setConn(conn)
		logrus.Info("connected to zwave-js server")

		err = c.listen(ctx)

		// Whatever ended the session, release the socket and wake anyone
		// still waiting for a reply on it before deciding what to do next.
		c.setConn(nil)
		_ = conn.Close() // best effort: the connection may already be closed
		c.failPendingCallbacks()

		switch {
		case err == nil:
			continue
		case errors.Is(err, net.ErrClosed):
			// The connection was closed locally (e.g. by Shutdown).
			return
		}

		logrus.WithError(err).Error("error communicating with zwavejs server")

		// Wait before reconnecting, but give up immediately on cancellation.
		timer := time.NewTimer(reconnectDelay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}

// listen reads messages from the active connection until a read, decode,
// write or state-sync error occurs, and returns that error.
//
// Message handling by type:
//   - "version": replies with the start-listening command.
//   - "result":  with an empty message ID the payload is synced into the
//     database; otherwise it is delivered to the waiting sender.
//   - "event":   passed to handleEvent; failures are logged, not returned.
func (c *Client) listen(ctx context.Context) error {
	conn := c.currentConn()
	if conn == nil {
		return errNotConnected
	}

	for {
		_, raw, err := conn.ReadMessage()
		if err != nil {
			return err
		}

		var msg IncomingMessage
		if err := json.Unmarshal(raw, &msg); err != nil {
			return err
		}

		logrus.WithField("type", msg.Type).Debug("received message from zwave-js")

		switch msg.Type {
		case "version":
			if err := c.writeJSON(OutgoingMessage{Command: CommandStartListening}); err != nil {
				return err
			}

		case "result":
			// Guard the dereference below: a result message without a
			// payload would otherwise panic the listen loop.
			if msg.Result == nil {
				logrus.WithField("message_id", msg.MessageID).Warn("received result message without a result payload")
				continue
			}

			if msg.MessageID == "" {
				if err := syncState(ctx, *msg.Result); err != nil {
					return fmt.Errorf("error syncing state from zwavejs server: %w", err)
				}
				continue
			}

			if err := c.handleCallback(msg.MessageID, *msg.Result); err != nil {
				logrus.WithError(err).Error("error handling callback")
			}

		case "event":
			if msg.Event == nil {
				logrus.Warn("received event message without an event payload")
				continue
			}
			if err := handleEvent(ctx, *msg.Event); err != nil {
				logrus.WithError(err).Error("error handling event")
			}

		default:
			logrus.WithField("type", msg.Type).Warn("received unexpected message type from zwave-js server")
		}
	}
}

// Shutdown closes the active connection, which makes DialAndListen return
// once its pending read fails. It is a no-op when there is no connection.
func (c *Client) Shutdown() error {
	conn := c.currentConn()
	if conn == nil {
		return nil
	}
	return conn.Close()
}

// syncState stores the user-code slots reported in result into the database.
//
// For every node that exposes at least one user-code value, the matching
// lock row is looked up by node ID (and created if missing) and each slot is
// upserted with its code and enabled flag.
func syncState(ctx context.Context, result Result) error {
	queries, dbc, err := db.Get()
	if err != nil {
		return err
	}
	defer dbc.Close()

	for _, node := range result.State.Nodes {
		// Values arrive one property at a time, so accumulate them per slot
		// number before writing anything.
		slots := make(map[int]db.LockCodeSlot)
		lockID := int64(-1)

		for _, value := range node.Values {
			if value.CommandClass != CommandClassUserCode {
				continue
			}

			lockID = int64(node.NodeID)
			slotNumber := value.PropertyKey.Int

			slot := slots[slotNumber] // zero value if this slot is new
			slot.Slot = int64(slotNumber)

			switch Property(value.PropertyName.String) {
			case PropertyUserCode:
				slot.Code = value.Value.String
			case PropertyUserIDStatus:
				slot.Enabled = value.Value.Int > 0
			}

			slots[slotNumber] = slot
		}

		// Nodes without user-code values are not tracked.
		if len(slots) == 0 || lockID < 0 {
			continue
		}

		lock, err := queries.GetLockByDeviceID(ctx, lockID)
		if err != nil {
			if errors.Is(err, sql.ErrNoRows) {
				lock, err = queries.CreateLock(ctx, lockID)
			}
			if err != nil {
				return err
			}
		}

		for _, slot := range slots {
			err := queries.UpsertCodeSlot(ctx, db.UpsertCodeSlotParams{
				Lock:    lock.ID,
				Code:    slot.Code,
				Slot:    slot.Slot,
				Enabled: slot.Enabled,
			})
			if err != nil {
				return fmt.Errorf("error upserting slot: %w", err)
			}
		}
	}

	return nil
}

// handleEvent records a node notification event as a lock log entry.
//
// Events that are not node notifications, or that come from a node with no
// lock row, are ignored. When the event carries a positive user ID, the log
// entry references the code stored in that slot.
func handleEvent(ctx context.Context, event Event) error {
	if event.Source != EventSourceNode || event.Event != EventTypeNotification {
		return nil
	}

	queries, dbc, err := db.Get()
	if err != nil {
		return err
	}
	defer dbc.Close()

	lock, err := queries.GetLockByDeviceID(ctx, int64(event.NodeID))
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}
		return fmt.Errorf("error getting lock: %w", err)
	}

	code := sql.NullInt64{}
	if event.Parameters.UserID > 0 {
		// NOTE(review): a slot missing from the database aborts here and the
		// event is not logged - confirm that is the intended behaviour.
		slot, err := queries.GetLockCodeBySlot(ctx, db.GetLockCodeBySlotParams{
			Lock: lock.ID,
			Slot: int64(event.Parameters.UserID),
		})
		if err != nil {
			return fmt.Errorf("error getting code slot: %w", err)
		}

		code = sql.NullInt64{
			Int64: slot.ID,
			Valid: true,
		}
	}

	err = queries.AddLogEntry(ctx, db.AddLogEntryParams{
		Lock:  lock.ID,
		Code:  code,
		State: event.NotificationLabel,
	})
	if err != nil {
		return fmt.Errorf("error adding log entry: %w", err)
	}

	logrus.WithFields(logrus.Fields{
		"lock":  lock.ID,
		"code":  code,
		"state": event.NotificationLabel,
	}).Debug("processed lock event")

	return nil
}

// handleCallback delivers result to the sender waiting on messageID.
//
// The callback is unregistered before the send and the send happens outside
// callbacksLock, so a duplicate or late response can neither reach a channel
// twice nor stall other senders. Unknown IDs are logged and ignored. The
// returned error is currently always nil.
func (c *Client) handleCallback(messageID string, result Result) error {
	c.callbacksLock.Lock()
	cb, ok := c.callbacks[messageID]
	if ok {
		delete(c.callbacks, messageID)
	}
	c.callbacksLock.Unlock()

	if !ok {
		logrus.WithField("message_id", messageID).Warn("got response to a message we didn't send")
		return nil
	}

	// Channels registered by sendMessageContext have room for one result,
	// so this does not block the listen loop.
	cb <- result
	return nil
}

// sendMessage sends message and waits, without a deadline, for the server's
// response. Prefer sendMessageContext when a context is available.
func (c *Client) sendMessage(message OutgoingMessageIface) (Result, error) {
	return c.sendMessageContext(context.Background(), message)
}

// sendMessageContext assigns message a fresh ID, sends it, and waits for the
// matching response.
//
// It returns the write error if sending fails, ctx.Err() if ctx is done
// first, and errConnectionLost if the connection drops while waiting.
func (c *Client) sendMessageContext(ctx context.Context, message OutgoingMessageIface) (Result, error) {
	messageID := uuid.New().String()
	message.SetMessageID(messageID)

	// Buffered so the listen loop can hand over the result even if this
	// caller has already stopped waiting.
	ch := make(chan Result, 1)

	c.callbacksLock.Lock()
	if c.callbacks == nil {
		c.callbacks = make(map[string]chan Result)
	}
	c.callbacks[messageID] = ch
	c.callbacksLock.Unlock()

	// Always unregister, including on write failure and cancellation, so
	// abandoned requests do not accumulate in the map.
	defer func() {
		c.callbacksLock.Lock()
		delete(c.callbacks, messageID)
		c.callbacksLock.Unlock()
	}()

	if err := c.writeJSON(message); err != nil {
		return Result{}, err
	}

	select {
	case result, ok := <-ch:
		if !ok {
			return Result{}, errConnectionLost
		}
		return result, nil
	case <-ctx.Done():
		return Result{}, ctx.Err()
	}
}

// SetNodeValue asks the server to set valueID on node nodeID to value and
// waits for the reply. It returns an error if sending fails, if ctx is done
// before the reply arrives, or if the server reports a non-successful result.
func (c *Client) SetNodeValue(ctx context.Context, nodeID int, valueID NodeValue, value AnyType) error {
	msg := NodeSetValueMessage{
		OutgoingMessage: OutgoingMessage{Command: CommandNodeSetValue},
		NodeID:          nodeID,
		ValueID:         valueID,
		Value:           value,
	}

	result, err := c.sendMessageContext(ctx, &msg)
	if err != nil {
		return err
	}

	if !result.Success {
		return errors.New("non-successful response from zwave-js server")
	}

	return nil
}