diff --git a/ascii_transport_test.go b/ascii_transport_test.go new file mode 100644 index 0000000..19c8cea --- /dev/null +++ b/ascii_transport_test.go @@ -0,0 +1,274 @@ +//go:build linux || freebsd || openbsd || netbsd +// +build linux freebsd openbsd netbsd + +// Copyright 2014 Quoc-Viet Nguyen. All rights reserved. +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +package modbus + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "path/filepath" + "strings" + "testing" + "time" + + serialpkg "github.com/grid-x/serial" +) + +func TestASCIISerialTransporter_Send_PTY(t *testing.T) { + master, slavePath, err := openPTY() + if err != nil { + t.Skipf("Skipping PTY test: %v", err) + } + defer master.Close() + + // Request: 01 03 00 00 00 01 (Read Holding Registers) + // ASCII: :010300000001FB\r\n + reqASCII := []byte(":010300000001FB\r\n") + + // Response: 01 03 02 00 00 + // ASCII: :0103020000FA\r\n + respASCII := []byte(":0103020000FA\r\n") + + transporter := &asciiSerialTransporter{} + transporter.Address = slavePath + transporter.BaudRate = 19200 + transporter.Timeout = 1 * time.Second + transporter.IdleTimeout = serialIdleTimeout + + // Start a goroutine to read request and write response to master + go func() { + buf := make([]byte, 1024) + n, err := master.Read(buf) + if err != nil { + return + } + if !bytes.Equal(buf[:n], reqASCII) { + // t.Errorf would be racy here, just log or ignore + return + } + // Write response + _, err = master.Write(respASCII) + if err != nil { + t.Errorf("Failed to write response: %v", err) + } + }() + + ctx := context.Background() + aduResponse, err := transporter.Send(ctx, reqASCII) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + if !bytes.Equal(aduResponse, respASCII) { + t.Errorf("Expected response %s, got %s", respASCII, aduResponse) + } +} + +func TestASCIISerialTransporter_Timeout_PTY(t *testing.T) { + master, slavePath, err := openPTY() + if err != nil { + t.Skipf("Skipping PTY test: %v", err) + } + defer master.Close() + + reqASCII := []byte(":010300000001FB\r\n") + + transporter := &asciiSerialTransporter{} + transporter.Address = slavePath + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + transporter.IdleTimeout = serialIdleTimeout + + // Don't write anything to master + + ctx := context.Background() + _, err = transporter.Send(ctx, reqASCII) + if err == nil { + t.Fatal("Expected timeout error, got nil") + } +} + +func TestASCIISerialTransporter_ReconnectOnMidCommunicationEOF_PTY(t *testing.T) { + master, slavePath, err := openPTY() + if err != nil { + t.Skipf("Skipping PTY test: %v", err) + } + defer master.Close() + + reqASCII := []byte(":010300000001FB\r\n") + respASCII := []byte(":0103020000FA\r\n") + partialResp := respASCII[:len(respASCII)-2] + var logs bytes.Buffer + + transporter := &asciiSerialTransporter{} + transporter.Address = slavePath + transporter.BaudRate = 19200 + transporter.Timeout = 200 * time.Millisecond + transporter.IdleTimeout = serialIdleTimeout + transporter.LinkRecoveryTimeout = 200 * time.Millisecond + transporter.Logger = log.New(&logs, "", 0) + + serverDone := make(chan error, 1) + + go func() { + buf := make([]byte, 1024) + n, err := master.Read(buf) + if err != nil { + serverDone <- fmt.Errorf("failed to read initial request: %w", err) + return + } + if !bytes.Equal(buf[:n], reqASCII) { + serverDone <- fmt.Errorf("unexpected initial request: got %q want %q", buf[:n], reqASCII) + return + } + if _, err := master.Write(partialResp); err != nil { + serverDone <- fmt.Errorf("failed to write partial response: %w", err) + return + } + if err := master.Close(); err != nil { + serverDone <- fmt.Errorf("failed to close initial PTY master: %w", err) + return + } + serverDone <- nil + }() + + _, err = transporter.Send(context.Background(), reqASCII) + if err == nil { + t.Fatal("expected Send to fail after reconnect attempt, got nil") + } + + if !strings.Contains(logs.String(), "connection reset, reconnecting") { + t.Fatalf("expected reconnect log entry, got %q", logs.String()) + } + + select { + case err := <-serverDone: + if err != nil { + t.Fatal(err) + } + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for PTY server") + } + + if !strings.Contains(err.Error(), "could not open") && !strings.Contains(err.Error(), "link recovery timeout reached") { + t.Fatalf("expected reconnect-related error, got %v", err) + } +} + +func TestASCIISerialTransporter_PartialResponseThenTimeout(t *testing.T) { + reqASCII := []byte(":010300000001FB\r\n") + respASCII := []byte(":0103020000FA\r\n") + + port := &scriptedPort{ + readData: respASCII[:len(respASCII)-2], + readErr: serialpkg.ErrTimeout, + } + + transporter := &asciiSerialTransporter{} + transporter.port = port + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + + _, err := transporter.Send(context.Background(), reqASCII) + if !errors.Is(err, serialpkg.ErrTimeout) { + t.Fatalf("expected timeout after partial response, got %v", err) + } + if got := port.written.Bytes(); !bytes.Equal(got, reqASCII) { + t.Fatalf("expected request %q, got %q", reqASCII, got) + } +} + +func TestASCIISerialTransporter_RecoveryDisabledOnReadEOF(t *testing.T) { + reqASCII := []byte(":010300000001FB\r\n") + port := &scriptedPort{readErr: io.EOF} + + transporter := &asciiSerialTransporter{} + transporter.Address = filepath.Join(t.TempDir(), "missing-serial") + transporter.port = port + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + transporter.LinkRecoveryTimeout = 0 + + _, err := transporter.Send(context.Background(), reqASCII) + if err == nil { + t.Fatal("expected link recovery timeout error, got nil") + } + if !strings.Contains(err.Error(), "link recovery timeout reached") || !errors.Is(err, io.EOF) { + t.Fatalf("expected link recovery timeout wrapping EOF, got %v", err) + } + if got := port.written.Bytes(); !bytes.Equal(got, reqASCII) { + t.Fatalf("expected request %q, got %q", reqASCII, got) + } +} + +func TestASCIISerialTransporter_ReconnectBudgetExhaustedOnReadEOF(t *testing.T) { + reqASCII := []byte(":010300000001FB\r\n") + port := &scriptedPort{readErr: io.EOF} + recoveryTimeout := 80 * time.Millisecond + + transporter := &asciiSerialTransporter{} + transporter.Address = filepath.Join(t.TempDir(), "missing-serial") + transporter.port = port + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + transporter.LinkRecoveryTimeout = recoveryTimeout + + start := time.Now() + _, err := transporter.Send(context.Background(), reqASCII) + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected link recovery timeout error, got nil") + } + if !strings.Contains(err.Error(), "link recovery timeout reached") { + t.Fatalf("expected link recovery timeout error, got %v", err) + } + if !strings.Contains(err.Error(), "could not open") { + t.Fatalf("expected reconnect open failure to be wrapped, got %v", err) + } + if elapsed < recoveryTimeout-20*time.Millisecond { + t.Fatalf("expected recovery to keep retrying for about %v, returned after %v", recoveryTimeout, elapsed) + } + if !port.closed { + t.Fatal("expected reconnect to close the original port after read EOF") + } + if got := port.written.Bytes(); !bytes.Equal(got, reqASCII) { + t.Fatalf("expected request %q, got %q", reqASCII, got) + } +} + +func TestASCIISerialTransporter_ReconnectOnWriteEOF(t *testing.T) { + reqASCII := []byte(":010300000001FB\r\n") + port := &scriptedPort{writeErr: io.EOF} + recoveryTimeout := 80 * time.Millisecond + + transporter := &asciiSerialTransporter{} + transporter.Address = filepath.Join(t.TempDir(), "missing-serial") + transporter.port = port + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + transporter.LinkRecoveryTimeout = recoveryTimeout + + start := time.Now() + _, err := transporter.Send(context.Background(), reqASCII) + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected reconnect error after write EOF, got nil") + } + if !strings.Contains(err.Error(), "link recovery timeout reached") || !strings.Contains(err.Error(), "could not open") { + t.Fatalf("expected timed-out reconnect open failure, got %v", err) + } + if elapsed < recoveryTimeout-20*time.Millisecond { + t.Fatalf("expected recovery to keep retrying for about %v, returned after %v", recoveryTimeout, elapsed) + } + if !port.closed { + t.Fatal("expected reconnect to close the original port after write EOF") + } +} diff --git a/asciiclient.go b/asciiclient.go index a310372..3aa6a0d 100644 --- a/asciiclient.go +++ b/asciiclient.go @@ -9,6 +9,7 @@ import ( "context" "encoding/hex" "fmt" + "io" "time" ) @@ -35,6 +36,7 @@ func NewASCIIClientHandler(address string) *ASCIIClientHandler { handler.Address = address handler.Timeout = serialTimeout handler.IdleTimeout = serialIdleTimeout + handler.ReconnectRetryInterval = serialReconnectRetryInterval return handler } @@ -181,17 +183,54 @@ func (mb *asciiSerialTransporter) Send(ctx context.Context, aduRequest []byte) ( mb.lastActivity = time.Now() mb.startCloseTimer() - // Send the request - mb.logf("modbus: send % x\n", aduRequest) - if _, err = mb.port.Write(aduRequest); err != nil { + linkRecoveryDeadline := time.Now().Add(mb.LinkRecoveryTimeout) + + for { + // Send the request + mb.logf("modbus: send % x\n", aduRequest) + if _, err = mb.port.Write(aduRequest); err != nil { + if mb.shouldRecover(err) { + if err = mb.reconnect(ctx, err, linkRecoveryDeadline); err != nil { + return + } + continue + } + + return + } + // Get the response + connDeadline := time.Now().Add(mb.Timeout) + aduResponse, err = readASCII(mb.port, connDeadline) + if aduResponse != nil { + mb.logf("modbus: recv % x\n", aduResponse[:]) + } + if err != nil { + if mb.shouldRecover(err) { + if err = mb.reconnect(ctx, err, linkRecoveryDeadline); err != nil { + return + } + continue + } + // Unknown error + mb.logf("modbus: read error: %v", err) + return + } + return } - // Get the response +} + +func readASCII(r io.Reader, deadline time.Time) ([]byte, error) { var n, length int var data [asciiMaxSize]byte + var err error + for { - if n, err = mb.port.Read(data[length:]); err != nil { - return + if time.Now().After(deadline) { + return nil, fmt.Errorf("failed to read from serial port before deadline: %w", context.DeadlineExceeded) + } + if n, err = r.Read(data[length:]); err != nil { + return nil, err } length += n if length >= asciiMaxSize || n == 0 { @@ -204,9 +243,8 @@ func (mb *asciiSerialTransporter) Send(ctx context.Context, aduRequest []byte) ( } } } - aduResponse = data[:length] - mb.logf("modbus: recv % x\n", aduResponse) - return + + return data[:length], nil } // writeHex encodes byte to string in hexadecimal, e.g. 0xA5 => "A5" diff --git a/asciiclient_test.go b/asciiclient_test.go index 1502066..006cadd 100644 --- a/asciiclient_test.go +++ b/asciiclient_test.go @@ -6,9 +6,31 @@ package modbus import ( "bytes" + "context" + "errors" + "io" + "strings" "testing" + "time" ) +type asciiTestReader struct { + readData []byte + readErr error +} + +func (r *asciiTestReader) Read(b []byte) (int, error) { + if len(r.readData) > 0 { + b[0] = r.readData[0] + r.readData = r.readData[1:] + return 1, nil + } + if r.readErr != nil { + return 0, r.readErr + } + return 0, io.EOF +} + func TestASCIIEncoding(t *testing.T) { encoder := asciiPackager{} encoder.SlaveID = 17 @@ -68,6 +90,121 @@ func TestASCIIDecodeStartCharacter(t *testing.T) { } } +func TestASCIIDecodeInvalidLRC(t *testing.T) { + decoder := asciiPackager{} + adu := []byte(":F7031389000A61\r\n") + + _, err := decoder.Decode(adu) + if err == nil { + t.Fatal("expected invalid LRC error, got nil") + } + if !strings.Contains(err.Error(), "response lrc") { + t.Fatalf("expected LRC mismatch error, got %v", err) + } +} + +func TestASCIIVerifyErrors(t *testing.T) { + decoder := asciiPackager{} + aduReq := []byte(":010300010002F9\r\n") + + testcases := []struct { + name string + aduResp []byte + expectedErr string + }{ + { + name: "too short", + aduResp: []byte(":01\r\n"), + expectedErr: "does not meet minimum", + }, + { + name: "odd payload length", + aduResp: []byte(":010304010F1509CA\r"), + expectedErr: "is not an even number", + }, + { + name: "invalid start character", + aduResp: []byte("!010304010F1509CA\r\n"), + expectedErr: "is not started", + }, + { + name: "invalid frame terminator", + aduResp: []byte(":010304010F1509CA\n\n"), + expectedErr: "is not ended", + }, + { + name: "slave mismatch", + aduResp: []byte(":020304010F1509CA\r\n"), + expectedErr: "does not match request", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := decoder.Verify(aduReq, tc.aduResp) + if err == nil { + t.Fatal("expected Verify to fail, got nil") + } + if !strings.Contains(err.Error(), tc.expectedErr) { + t.Fatalf("expected error containing %q, got %v", tc.expectedErr, err) + } + }) + } +} + +func TestReadASCII(t *testing.T) { + testcases := []struct { + name string + reader *asciiTestReader + deadline time.Time + want []byte + wantErr error + }{ + { + name: "complete frame", + reader: &asciiTestReader{readData: []byte(":0103020000FA\r\n")}, + deadline: time.Now().Add(time.Second), + want: []byte(":0103020000FA\r\n"), + }, + { + name: "stops at terminator", + reader: &asciiTestReader{readData: []byte(":0103020000FA\r\nignored")}, + deadline: time.Now().Add(time.Second), + want: []byte(":0103020000FA\r\n"), + }, + { + name: "reader timeout after partial frame", + reader: &asciiTestReader{readData: []byte(":0103020000FA\r"), readErr: context.DeadlineExceeded}, + deadline: time.Now().Add(time.Second), + wantErr: context.DeadlineExceeded, + }, + { + name: "deadline exceeded before read", + reader: &asciiTestReader{readData: []byte(":0103020000FA\r\n")}, + deadline: time.Now().Add(-time.Millisecond), + wantErr: context.DeadlineExceeded, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + got, err := readASCII(tc.reader, tc.deadline) + if tc.wantErr != nil { + if !errors.Is(err, tc.wantErr) { + t.Fatalf("expected error %v, got %v", tc.wantErr, err) + } + return + } + if err != nil { + t.Fatalf("readASCII returned error: %v", err) + } + if !bytes.Equal(got, tc.want) { + t.Fatalf("expected %q, got %q", tc.want, got) + } + }) + } +} + func BenchmarkASCIIEncoder(b *testing.B) { encoder := asciiPackager{ SlaveID: 10, diff --git a/rtu_transport_test.go b/rtu_transport_test.go new file mode 100644 index 0000000..02d480b --- /dev/null +++ b/rtu_transport_test.go @@ -0,0 +1,346 @@ +//go:build linux || freebsd || openbsd || netbsd +// +build linux freebsd openbsd netbsd + +// Copyright 2014 Quoc-Viet Nguyen. All rights reserved. +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +package modbus + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" + "sync" + "syscall" + "testing" + "time" + "unsafe" + + serialpkg "github.com/grid-x/serial" +) + +// openPTY opens a PTY pair and returns the master file and the slave path. +func openPTY() (master *os.File, slavePath string, err error) { + master, err = os.OpenFile("/dev/ptmx", os.O_RDWR, 0) + if err != nil { + return nil, "", err + } + + // unlockpt + var unlock int32 + // TIOCSPTLCK + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, master.Fd(), syscall.TIOCSPTLCK, uintptr(unsafe.Pointer(&unlock))) + if errno != 0 { + master.Close() + return nil, "", errno + } + + // ptsname + var ptyno int32 + // TIOCGPTN + _, _, errno = syscall.Syscall(syscall.SYS_IOCTL, master.Fd(), syscall.TIOCGPTN, uintptr(unsafe.Pointer(&ptyno))) + if errno != 0 { + master.Close() + return nil, "", errno + } + + slavePath = fmt.Sprintf("/dev/pts/%d", ptyno) + return master, slavePath, nil +} + +type scriptedPort struct { + mu sync.Mutex + readData []byte + readErr error + writeErr error + written bytes.Buffer + closed bool +} + +func (p *scriptedPort) Read(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.readData) > 0 { + b[0] = p.readData[0] + p.readData = p.readData[1:] + return 1, nil + } + if p.readErr != nil { + return 0, p.readErr + } + return 0, io.EOF +} + +func (p *scriptedPort) Write(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.writeErr != nil { + return 0, p.writeErr + } + return p.written.Write(b) +} + +func (p *scriptedPort) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + p.closed = true + return nil +} + +func TestRTUSerialTransporter_Send_PTY(t *testing.T) { + master, slavePath, err := openPTY() + if err != nil { + t.Skipf("Skipping PTY test: %v", err) + } + defer master.Close() + + // Request: 01 03 00 00 00 01 84 0A (Read Holding Registers) + // Response: 01 03 02 00 00 B8 44 + req := []byte{0x01, 0x03, 0x00, 0x00, 0x00, 0x01, 0x84, 0x0A} + resp := []byte{0x01, 0x03, 0x02, 0x00, 0x00, 0xB8, 0x44} + + transporter := &rtuSerialTransporter{} + transporter.Address = slavePath + transporter.BaudRate = 19200 + transporter.Timeout = 1 * time.Second + + // Start a goroutine to read request and write response to master + go func() { + buf := make([]byte, 1024) + n, err := master.Read(buf) + if err != nil { + return + } + if !bytes.Equal(buf[:n], req) { + // t.Errorf would be racy here, just log or ignore + return + } + // Write response + _, err = master.Write(resp) + if err != nil { + t.Errorf("Failed to write response: %v", err) + } + }() + + ctx := context.Background() + aduResponse, err := transporter.Send(ctx, req) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + if !bytes.Equal(aduResponse, resp) { + t.Errorf("Expected response %x, got %x", resp, aduResponse) + } +} + +func TestRTUSerialTransporter_Timeout_PTY(t *testing.T) { + master, slavePath, err := openPTY() + if err != nil { + t.Skipf("Skipping PTY test: %v", err) + } + defer master.Close() + + req := []byte{0x01, 0x03, 0x00, 0x00, 0x00, 0x01, 0x84, 0x0A} + + transporter := &rtuSerialTransporter{} + transporter.Address = slavePath + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + + // Don't write anything to master + + ctx := context.Background() + _, err = transporter.Send(ctx, req) + if err == nil { + t.Fatal("Expected timeout error, got nil") + } +} + +func TestRTUSerialTransporter_ReconnectOnMidCommunicationEOF_PTY(t *testing.T) { + master1, slavePath1, err := openPTY() + if err != nil { + t.Skipf("Skipping PTY test: %v", err) + } + defer master1.Close() + + // Request: 01 03 00 00 00 01 84 0A (Read Holding Registers) + // Response: 01 03 02 00 00 B8 44 + req := []byte{0x01, 0x03, 0x00, 0x00, 0x00, 0x01, 0x84, 0x0A} + resp := []byte{0x01, 0x03, 0x02, 0x00, 0x00, 0xB8, 0x44} + partialResp := resp[:len(resp)-1] + var logs bytes.Buffer + + transporter := &rtuSerialTransporter{} + transporter.Address = slavePath1 + transporter.BaudRate = 19200 + transporter.Timeout = 200 * time.Millisecond + transporter.IdleTimeout = serialIdleTimeout + transporter.LinkRecoveryTimeout = 200 * time.Millisecond + transporter.Logger = log.New(&logs, "", 0) + + serverDone := make(chan error, 1) + + go func() { + buf := make([]byte, 1024) + n, err := master1.Read(buf) + if err != nil { + serverDone <- fmt.Errorf("failed to read initial request: %w", err) + return + } + if !bytes.Equal(buf[:n], req) { + serverDone <- fmt.Errorf("unexpected initial request: got %x want %x", buf[:n], req) + return + } + if _, err := master1.Write(partialResp); err != nil { + serverDone <- fmt.Errorf("failed to write partial response: %w", err) + return + } + if err := master1.Close(); err != nil { + serverDone <- fmt.Errorf("failed to close initial PTY master: %w", err) + return + } + serverDone <- nil + }() + + ctx := context.Background() + _, err = transporter.Send(ctx, req) + if err == nil { + t.Fatal("expected Send to fail after reconnect attempt, got nil") + } + + if !strings.Contains(logs.String(), "connection reset, reconnecting") { + t.Fatalf("expected reconnect log entry, got %q", logs.String()) + } + + select { + case err := <-serverDone: + if err != nil { + t.Fatal(err) + } + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for PTY server") + } + + if !strings.Contains(err.Error(), "could not open") && !strings.Contains(err.Error(), "link recovery timeout reached") { + t.Fatalf("expected reconnect-related error, got %v", err) + } +} + +func TestRTUSerialTransporter_PartialResponseThenTimeout(t *testing.T) { + req := []byte{0x01, 0x03, 0x00, 0x00, 0x00, 0x01, 0x84, 0x0A} + resp := []byte{0x01, 0x03, 0x02, 0x00, 0x00, 0xB8, 0x44} + + port := &scriptedPort{ + readData: resp[:len(resp)-2], + readErr: serialpkg.ErrTimeout, + } + + transporter := &rtuSerialTransporter{} + transporter.port = port + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + + _, err := transporter.Send(context.Background(), req) + if !errors.Is(err, serialpkg.ErrTimeout) { + t.Fatalf("expected timeout after partial response, got %v", err) + } + if got := port.written.Bytes(); !bytes.Equal(got, req) { + t.Fatalf("expected request %x, got %x", req, got) + } +} + +func TestRTUSerialTransporter_RecoveryDisabledOnReadEOF(t *testing.T) { + req := []byte{0x01, 0x03, 0x00, 0x00, 0x00, 0x01, 0x84, 0x0A} + port := &scriptedPort{readErr: io.EOF} + + transporter := &rtuSerialTransporter{} + transporter.Address = filepath.Join(t.TempDir(), "missing-serial") + transporter.port = port + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + transporter.LinkRecoveryTimeout = 0 + + _, err := transporter.Send(context.Background(), req) + if err == nil { + t.Fatal("expected link recovery timeout error, got nil") + } + if !strings.Contains(err.Error(), "link recovery timeout reached") || !errors.Is(err, io.EOF) { + t.Fatalf("expected link recovery timeout wrapping EOF, got %v", err) + } + if got := port.written.Bytes(); !bytes.Equal(got, req) { + t.Fatalf("expected request %x, got %x", req, got) + } +} + +func TestRTUSerialTransporter_ReconnectBudgetExhaustedOnReadEOF(t *testing.T) { + req := []byte{0x01, 0x03, 0x00, 0x00, 0x00, 0x01, 0x84, 0x0A} + port := &scriptedPort{readErr: io.EOF} + recoveryTimeout := 80 * time.Millisecond + + transporter := &rtuSerialTransporter{} + transporter.Address = filepath.Join(t.TempDir(), "missing-serial") + transporter.port = port + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + transporter.LinkRecoveryTimeout = recoveryTimeout + + start := time.Now() + _, err := transporter.Send(context.Background(), req) + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected link recovery timeout error, got nil") + } + if !strings.Contains(err.Error(), "link recovery timeout reached") { + t.Fatalf("expected link recovery timeout error, got %v", err) + } + if !strings.Contains(err.Error(), "could not open") { + t.Fatalf("expected reconnect open failure to be wrapped, got %v", err) + } + if elapsed < recoveryTimeout-20*time.Millisecond { + t.Fatalf("expected recovery to keep retrying for about %v, returned after %v", recoveryTimeout, elapsed) + } + if !port.closed { + t.Fatal("expected reconnect to close the original port after read EOF") + } + if got := port.written.Bytes(); !bytes.Equal(got, req) { + t.Fatalf("expected request %x, got %x", req, got) + } +} + +func TestRTUSerialTransporter_ReconnectOnWriteEOF(t *testing.T) { + req := []byte{0x01, 0x03, 0x00, 0x00, 0x00, 0x01, 0x84, 0x0A} + port := &scriptedPort{writeErr: io.EOF} + recoveryTimeout := 80 * time.Millisecond + + transporter := &rtuSerialTransporter{} + transporter.Address = filepath.Join(t.TempDir(), "missing-serial") + transporter.port = port + transporter.BaudRate = 19200 + transporter.Timeout = 100 * time.Millisecond + transporter.LinkRecoveryTimeout = recoveryTimeout + + start := time.Now() + _, err := transporter.Send(context.Background(), req) + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected reconnect error after write EOF, got nil") + } + if !strings.Contains(err.Error(), "link recovery timeout reached") || !strings.Contains(err.Error(), "could not open") { + t.Fatalf("expected timed-out reconnect open failure, got %v", err) + } + if elapsed < recoveryTimeout-20*time.Millisecond { + t.Fatalf("expected recovery to keep retrying for about %v, returned after %v", recoveryTimeout, elapsed) + } + if !port.closed { + t.Fatal("expected reconnect to close the original port after write EOF") + } +} diff --git a/rtuclient.go b/rtuclient.go index e1bf12d..9874ba4 100644 --- a/rtuclient.go +++ b/rtuclient.go @@ -7,7 +7,6 @@ package modbus import ( "context" "encoding/binary" - "errors" "fmt" "io" "time" @@ -56,6 +55,7 @@ func NewRTUClientHandler(address string) *RTUClientHandler { handler.Address = address handler.Timeout = serialTimeout handler.IdleTimeout = serialIdleTimeout + handler.ReconnectRetryInterval = serialReconnectRetryInterval return handler } @@ -168,7 +168,7 @@ func readIncrementally(slaveID, functionCode byte, r io.Reader, deadline time.Ti for { if time.Now().After(deadline) { // Possible that serialport may spew data - return nil, errors.New("failed to read from serial port within deadline") + return nil, fmt.Errorf("failed to read from serial port within deadline: %w", context.DeadlineExceeded) } if _, err := io.ReadAtLeast(r, buf, 1); err != nil { @@ -264,24 +264,48 @@ func (mb *rtuSerialTransporter) Send(ctx context.Context, aduRequest []byte) (ad mb.lastActivity = time.Now() mb.startCloseTimer() - // Send the request - mb.logf("modbus: send % x\n", aduRequest) - if _, err = mb.port.Write(aduRequest); err != nil { + linkRecoveryDeadline := time.Now().Add(mb.LinkRecoveryTimeout) + + for { + // Send the request + mb.logf("modbus: send % x\n", aduRequest) + if _, err = mb.port.Write(aduRequest); err != nil { + if mb.shouldRecover(err) { + if err = mb.reconnect(ctx, err, linkRecoveryDeadline); err != nil { + return + } + continue + } + + return + } + bytesToRead := calculateResponseLength(aduRequest) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(mb.calculateDelay(len(aduRequest) + bytesToRead)): + } + + connDeadline := time.Now().Add(mb.Timeout) + aduResponse, err = readIncrementally(aduRequest[0], aduRequest[1], mb.port, connDeadline) + if aduResponse != nil { + mb.logf("modbus: recv % x\n", aduResponse[:]) + } + + if err != nil { + if mb.shouldRecover(err) { + if err = mb.reconnect(ctx, err, linkRecoveryDeadline); err != nil { + return + } + continue + } + // Unknown error + mb.logf("modbus: read error: %v", err) + return + } return } - // function := aduRequest[1] - // functionFail := aduRequest[1] & 0x80 - bytesToRead := calculateResponseLength(aduRequest) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(mb.calculateDelay(len(aduRequest) + bytesToRead)): - } - data, err := readIncrementally(aduRequest[0], aduRequest[1], mb.port, time.Now().Add(mb.Config.Timeout)) - mb.logf("modbus: recv % x\n", data[:]) - aduResponse = data - return } // calculateDelay roughly calculates time needed for the next frame. diff --git a/serial.go b/serial.go index 5ad9692..f9199c4 100644 --- a/serial.go +++ b/serial.go @@ -6,6 +6,7 @@ package modbus import ( "context" + "errors" "fmt" "io" "sync" @@ -18,6 +19,8 @@ const ( // Default timeout serialTimeout = 5 * time.Second serialIdleTimeout = 60 * time.Second + // Retry interval while spending the link recovery budget on reconnects. + serialReconnectRetryInterval = 10 * time.Millisecond ) // serialPort has configuration and I/O controller. @@ -25,8 +28,16 @@ type serialPort struct { // Serial port configuration. serial.Config - Logger Logger + Logger Logger + // IdleTimeout is the duration to close the connection when no activity. IdleTimeout time.Duration + // Silent period after successful connection + ConnectDelay time.Duration + // Recovery timeout if the connection is lost + LinkRecoveryTimeout time.Duration + // Interval between reconnect attempts while spending the link recovery budget. + // Zero or negative values fall back to the default retry interval. + ReconnectRetryInterval time.Duration mu sync.Mutex // port is platform-dependent data structure for serial port. @@ -43,6 +54,7 @@ func (mb *serialPort) Connect(ctx context.Context) (err error) { } // connect connects to the serial port if it is not connected. Caller must hold the mutex. +// Note: caller must handle the connection close and recovery if the connection is lost. func (mb *serialPort) connect(ctx context.Context) error { select { case <-ctx.Done(): @@ -52,9 +64,14 @@ func (mb *serialPort) connect(ctx context.Context) error { if mb.port == nil { port, err := serial.Open(&mb.Config) if err != nil { - return fmt.Errorf("could not open %s: %w", mb.Config.Address, err) + return fmt.Errorf("could not open %s: %w", mb.Address, err) } mb.port = port + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(mb.ConnectDelay): //silent period + } } return nil } @@ -81,6 +98,52 @@ func (mb *serialPort) logf(format string, v ...interface{}) { } } +func (mb *serialPort) shouldRecover(err error) bool { + return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) +} + +func (mb *serialPort) reconnect(ctx context.Context, err error, linkRecoveryDeadline time.Time) error { + if mb.LinkRecoveryTimeout == 0 || time.Until(linkRecoveryDeadline) < 0 { + return fmt.Errorf("modbus: link recovery timeout reached: %w", err) + } + + mb.logf("modbus: connection reset, reconnecting") + recoveryErr := err + if cerr := mb.close(); cerr != nil { + recoveryErr = errors.Join(recoveryErr, cerr) + mb.logf("modbus: error closing connection: %v", cerr) + } + + deadlineTimer := time.NewTimer(time.Until(linkRecoveryDeadline)) + defer deadlineTimer.Stop() + retryTicker := time.NewTicker(mb.reconnectRetryInterval()) + defer retryTicker.Stop() + + for { + if cerr := mb.connect(ctx); cerr == nil { + return nil + } else { + recoveryErr = errors.Join(recoveryErr, cerr) + mb.logf("modbus: error reconnecting: %v", cerr) + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-deadlineTimer.C: + return fmt.Errorf("modbus: link recovery timeout reached: %w", recoveryErr) + case <-retryTicker.C: + } + } +} + +func (mb *serialPort) reconnectRetryInterval() time.Duration { + if mb.ReconnectRetryInterval > 0 { + return mb.ReconnectRetryInterval + } + return serialReconnectRetryInterval +} + func (mb *serialPort) startCloseTimer() { if mb.IdleTimeout <= 0 { return @@ -103,6 +166,6 @@ func (mb *serialPort) closeIdle() { if idle := time.Since(mb.lastActivity); idle >= mb.IdleTimeout { mb.logf("modbus: closing connection due to idle timeout: %v", idle) - mb.close() + _ = mb.close() } } diff --git a/serial_test.go b/serial_test.go index 226ebb5..86cd010 100644 --- a/serial_test.go +++ b/serial_test.go @@ -2,7 +2,13 @@ package modbus import ( "bytes" + "context" + "errors" "io" + "log" + "os" + "path/filepath" + "strings" "sync/atomic" "testing" "time" @@ -19,6 +25,18 @@ func (n *nopCloser) Close() error { return nil } +type errCloser struct { + io.ReadWriter + + closed atomic.Bool + err error +} + +func (e *errCloser) Close() error { + e.closed.Store(true) + return e.err +} + func TestSerialCloseIdle(t *testing.T) { port := &nopCloser{ ReadWriter: &bytes.Buffer{}, @@ -37,3 +55,128 @@ func TestSerialCloseIdle(t *testing.T) { t.Fatalf("serial port is not closed when inactivity: %+v", port) } } + +func TestSerialReconnect_UsesConfiguredRetryInterval(t *testing.T) { + var logs bytes.Buffer + port := &nopCloser{ReadWriter: &bytes.Buffer{}} + recoveryTimeout := 40 * time.Millisecond + + s := serialPort{ + Logger: log.New(&logs, "", 0), + port: port, + LinkRecoveryTimeout: recoveryTimeout, + ReconnectRetryInterval: 50 * time.Millisecond, + } + s.Address = filepath.Join(t.TempDir(), "missing-serial") + + err := s.reconnect(context.Background(), io.EOF, time.Now().Add(recoveryTimeout)) + if err == nil { + t.Fatal("expected reconnect to fail when the serial device is missing") + } + if !strings.Contains(err.Error(), "link recovery timeout reached") { + t.Fatalf("expected link recovery timeout error, got %v", err) + } + if count := strings.Count(logs.String(), "error reconnecting"); count != 1 { + t.Fatalf("expected exactly one reconnect attempt before timeout, got %d logs: %q", count, logs.String()) + } + if !port.closed.Load() || s.port != nil { + t.Fatalf("expected reconnect to close the original port: closed=%v port=%v", port.closed.Load(), s.port) + } +} + +func TestSerialReconnect_DefaultRetryIntervalRetriesMultipleTimes(t *testing.T) { + var logs bytes.Buffer + recoveryTimeout := 45 * time.Millisecond + + s := serialPort{ + Logger: log.New(&logs, "", 0), + port: &nopCloser{ReadWriter: &bytes.Buffer{}}, + LinkRecoveryTimeout: recoveryTimeout, + } + s.Address = filepath.Join(t.TempDir(), "missing-serial") + + err := s.reconnect(context.Background(), io.EOF, time.Now().Add(recoveryTimeout)) + if err == nil { + t.Fatal("expected reconnect to fail when the serial device is missing") + } + if !errors.Is(err, io.EOF) { + t.Fatalf("expected reconnect to preserve the original EOF, got %v", err) + } + if count := strings.Count(logs.String(), "error reconnecting"); count < 2 { + t.Fatalf("expected default retry interval to attempt reconnect multiple times, got %d logs: %q", count, logs.String()) + } + if count := strings.Count(err.Error(), "could not open"); count < 2 { + t.Fatalf("expected reconnect error to include multiple open failures, got %d in %q", count, err.Error()) + } +} + +func TestSerialReconnectHotPlug_EventuallySucceedsWithinRecoveryWindow_PTY(t *testing.T) { + var logs bytes.Buffer + + initialPort := &errCloser{ + ReadWriter: &bytes.Buffer{}, + err: errors.New("device disappeared"), + } + recoveryTimeout := 200 * time.Millisecond + stablePath := filepath.Join(t.TempDir(), "recovering-serial") + + type reopenResult struct { + master *os.File + err error + } + + reopenReady := make(chan reopenResult, 1) + go func() { + time.Sleep(45 * time.Millisecond) + + master, slavePath, err := openPTY() + if err != nil { + reopenReady <- reopenResult{err: err} + return + } + if err := os.Symlink(slavePath, stablePath); err != nil { + _ = master.Close() + reopenReady <- reopenResult{err: err} + return + } + reopenReady <- reopenResult{master: master} + }() + + s := serialPort{ + Logger: log.New(&logs, "", 0), + port: initialPort, + LinkRecoveryTimeout: recoveryTimeout, + ReconnectRetryInterval: 10 * time.Millisecond, + } + s.Address = stablePath + s.BaudRate = 19200 + s.Timeout = 50 * time.Millisecond + + err := s.reconnect(context.Background(), io.EOF, time.Now().Add(recoveryTimeout)) + if err != nil { + t.Fatalf("expected reconnect to succeed before timeout, got %v", err) + } + if s.port == nil || s.port == initialPort { + t.Fatalf("expected reconnect to replace the original port, got %v", s.port) + } + + select { + case result := <-reopenReady: + if result.err != nil { + t.Fatal(result.err) + } + t.Cleanup(func() { + _ = s.Close() + _ = result.master.Close() + }) + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for PTY reopen setup") + } + + if !strings.Contains(logs.String(), "error closing connection") { + t.Fatalf("expected close error to be logged, got %q", logs.String()) + } + if count := strings.Count(logs.String(), "error reconnecting"); count < 1 { + t.Fatalf("expected reconnect to log failed reopen attempts before success, got %d logs: %q", count, logs.String()) + } +} diff --git a/tcpclient.go b/tcpclient.go index d6fcbe0..0ff0c15 100644 --- a/tcpclient.go +++ b/tcpclient.go @@ -302,6 +302,11 @@ func (mb *tcpTransporter) Send(ctx context.Context, aduRequest []byte) (aduRespo } } +func (mb *tcpTransporter) shouldRecover(err error) bool { + return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, syscall.ECONNRESET) +} + + func (mb *tcpTransporter) readResponse(aduRequest []byte, data []byte, recoveryDeadline time.Time, protocolDeadline time.Time) (aduResponse []byte, res readResult, err error) { // res is readResultDone by default, which either means we succeeded or err contains the fatal error for { @@ -310,7 +315,7 @@ func (mb *tcpTransporter) readResponse(aduRequest []byte, data []byte, recoveryD if mb.LinkRecoveryTimeout == 0 || time.Until(recoveryDeadline) < 0 { return } - if err == io.EOF || err == io.ErrUnexpectedEOF || err == syscall.ECONNRESET { + if mb.shouldRecover(err) { mb.logf("modbus: connection closed by remote side: %v", err) res = readResultCloseRetry } @@ -322,7 +327,7 @@ func (mb *tcpTransporter) readResponse(aduRequest []byte, data []byte, recoveryD if mb.LinkRecoveryTimeout == 0 || time.Until(recoveryDeadline) < 0 { return } - if err == io.EOF || err == io.ErrUnexpectedEOF || err == syscall.ECONNRESET { + if mb.shouldRecover(err) { mb.logf("modbus: connection closed by remote side: %v", err) res = readResultCloseRetry return