@@ -6,6 +6,7 @@ package m3ua
66
77import (
88 "fmt"
9+ "math/rand"
910 "net"
1011 "sync"
1112 "time"
@@ -24,6 +25,9 @@ const (
2425
2526// Conn represents a M3UA connection, which satisfies standard net.Conn interface.
2627type Conn struct {
28+ // maxMessageStreamID is the maximum negotiated sctp stream ID used,
29+ // must not be zero, must vary from 1 to maxMessageStreamID
30+ maxMessageStreamID uint16
2731 // muState is to Lock when updating state
2832 muState * sync.RWMutex
2933 // mode represents the endpoint works as client or server
@@ -100,7 +104,9 @@ func (c *Conn) ReadPD() (pd *params.ProtocolDataPayload, err error) {
100104
101105// Write writes data to the connection.
102106func (c * Conn ) Write (b []byte ) (n int , err error ) {
103- return c .WriteToStream (b , c .StreamID ())
107+ stream := c .chooseStreamID ()
108+
109+ return c .WriteToStream (b , stream )
104110}
105111
106112// WriteToStream writes data to the connection and specific stream
@@ -133,7 +139,9 @@ func (c *Conn) WriteToStream(b []byte, streamID uint16) (n int, err error) {
133139
134140// WritePD writes data with a specific mtp3 protocol data to the connection.
135141func (c * Conn ) WritePD (protocolData * params.Param ) (n int , err error ) {
136- return c .WritePDToStream (protocolData , c .StreamID ())
142+ stream := c .chooseStreamID ()
143+
144+ return c .WritePDToStream (protocolData , stream )
137145}
138146
139147// WritePDToStream writes data with a specific mtp3 protocol data to the connection and specific stream
@@ -238,3 +246,19 @@ func (c *Conn) State() State {
238246func (c * Conn ) StreamID () uint16 {
239247 return c .sctpInfo .Stream
240248}
249+
250+ // MaxMessageStreamID returns the maximum negotiated sctp stream ID
251+ // The streamID for sending a message must start from 1 up to maxMessageStreamID, 0 is reserved for management messages
252+ func (c * Conn ) MaxMessageStreamID () uint16 {
253+ return c .maxMessageStreamID
254+ }
255+
256+ // chooseStreamID generates a random uint16 from 1 to max (inclusive)
257+ func (c * Conn ) chooseStreamID () uint16 {
258+ if c .maxMessageStreamID == 1 {
259+ return 1
260+ }
261+ r := rand .New (rand .NewSource (time .Now ().UnixNano ()))
262+ randomNum := uint16 (r .Intn (int (c .maxMessageStreamID )))
263+ return randomNum + 1
264+ }
0 commit comments