1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
16 // debugMux, if set, causes messages in the connection protocol to be
// logged. NOTE(review): the original continuation of this comment is missing
// from this excerpt; "logged" is inferred from the log.Printf calls later in
// the file — confirm.
18 const debugMux = false
20 // chanList is a thread-safe list of channels, indexed by local channel ID.
21 type chanList struct {
22 // protects concurrent access to chans
// NOTE(review): the field declarations themselves are missing from this
// excerpt (presumably a mutex, plus the chans and offset fields referenced by
// the methods below); only their comments survive.
25 // chans are indexed by the local ID of the channel, which the
26 // other side should send in the PeersID field.
29 // This is a debugging aid: it offsets all IDs by this
30 // amount. This helps distinguish otherwise identical
31 // server/client muxes.
35 // add assigns a channel ID to the given channel and returns that ID.
// If c.chans contains a nil slot, the index of the first such slot is used;
// otherwise ch is appended and the new last index is used. In both cases the
// returned ID is the slot index plus c.offset.
// NOTE(review): several original lines are missing from this excerpt
// (per the embedded numbering), presumably the lock and the store of ch into
// the free slot — confirm against the full file.
36 func (c *chanList) add(ch *channel) uint32 {
39 for i := range c.chans {
40 if c.chans[i] == nil {
42 return uint32(i) + c.offset
45 c.chans = append(c.chans, ch)
46 return uint32(len(c.chans)-1) + c.offset
49 // getChan returns the channel for the given ID.
// NOTE(review): the body is incomplete in this excerpt; only the bounds check
// against len(c.chans) is visible. Presumably id is first translated back by
// c.offset and an unknown id yields nil (onePacket treats a nil result as an
// invalid channel) — confirm.
50 func (c *chanList) getChan(id uint32) *channel {
55 if id < uint32(len(c.chans)) {
// remove forgets the channel with the given ID.
// NOTE(review): only the bounds check is visible in this excerpt; presumably
// the slot is cleared to nil so that add can reuse it — confirm.
61 func (c *chanList) remove(id uint32) {
64 if id < uint32(len(c.chans)) {
70 // dropAll forgets all channels it knows and returns them in a slice.
// mux.loop calls this once its packet loop has stopped.
// NOTE(review): the rest of the body is missing from this excerpt.
71 func (c *chanList) dropAll() []*channel {
76 for _, ch := range c.chans {
86 // mux represents the state for the SSH connection protocol, which
87 // multiplexes many channels onto a single packet transport.
// NOTE(review): the "type mux struct" header and several fields referenced
// elsewhere (conn, chanList, errCond, ...) are missing from this excerpt.
// incomingChannels carries channels opened by the peer (see handleChannelOpen).
92 incomingChannels chan NewChannel
// globalSentMu is held by SendRequest for its whole duration, so the next
// reply read from globalResponses belongs to the request just sent.
94 globalSentMu sync.Mutex
// globalResponses carries global request success/failure replies from the
// read loop (handleGlobalPacket) to SendRequest.
95 globalResponses chan interface{}
// incomingRequests carries global requests received from the peer.
96 incomingRequests chan *Request
102 // When debugging, each new chanList instantiation has a different
// NOTE(review): the rest of this comment and the declaration it documents
// (presumably globalOff, which newMux increments atomically) are missing from
// this excerpt.
// Wait blocks until the read loop (mux.loop) has exited; loop broadcasts on
// m.errCond when it finishes.
// NOTE(review): most of the body is missing from this excerpt; only the
// deferred unlock of m.errCond.L is visible. Presumably it returns the error
// that ended the loop — confirm.
106 func (m *mux) Wait() error {
108 defer m.errCond.L.Unlock()
115 // newMux returns a mux that runs over the given connection.
// NOTE(review): several lines of the constructor are missing from this excerpt
// (struct literal header, other fields, and whatever starts the read loop).
116 func newMux(p packetConn) *mux {
119 incomingChannels: make(chan NewChannel, chanSize),
// Capacity 1: SendRequest holds globalSentMu while it waits, so at most one
// reply is outstanding at a time.
120 globalResponses: make(chan interface{}, 1),
121 incomingRequests: make(chan *Request, chanSize),
// Give this mux's chanList a distinct ID offset (a debugging aid; the guarding
// condition, presumably debugMux, is on a line missing from this excerpt).
125 m.chanList.offset = atomic.AddUint32(&globalOff, 1)
// sendMessage writes msg to the peer as a single packet on m.conn.
// NOTE(review): the lines that encode msg into p and guard the log call
// (presumably with debugMux) are missing from this excerpt — confirm.
132 func (m *mux) sendMessage(msg interface{}) error {
135 log.Printf("send global(%d): %#v", m.chanList.offset, msg)
137 return m.conn.writePacket(p)
// SendRequest sends a global request named name with the given payload.
// When a reply is awaited, it returns whether the peer accepted the request
// (success vs. failure message) and the data carried in the reply. It returns
// io.EOF if globalResponses is closed, which loop does when it exits.
140 func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
// Serialize global requests so the next value on globalResponses is the
// reply to this request.
142 m.globalSentMu.Lock()
143 defer m.globalSentMu.Unlock()
146 if err := m.sendMessage(globalRequestMsg{
148 WantReply: wantReply,
151 return false, nil, err
// NOTE(review): the condition guarding this early return (presumably
// !wantReply) is missing from this excerpt.
155 return false, nil, nil
158 msg, ok := <-m.globalResponses
// A closed globalResponses channel means the read loop has exited.
160 return false, nil, io.EOF
162 switch msg := msg.(type) {
163 case *globalRequestFailureMsg:
164 return false, msg.Data, nil
165 case *globalRequestSuccessMsg:
166 return true, msg.Data, nil
// Any other message type is unexpected (the default label is missing from
// this excerpt).
168 return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
172 // ackRequest must be called after processing a global request that
173 // has WantReply set. It sends a success message carrying data, or a failure
// message carrying data. NOTE(review): the condition selecting between them
// (presumably ok) is missing from this excerpt.
174 func (m *mux) ackRequest(ok bool, data []byte) error {
176 return m.sendMessage(globalRequestSuccessMsg{Data: data})
178 return m.sendMessage(globalRequestFailureMsg{Data: data})
// Close closes the underlying packet connection and returns its error.
181 func (m *mux) Close() error {
182 return m.conn.Close()
185 // loop runs the connection machine. It will process packets until an
186 // error is encountered. To synchronize on loop exit, use mux.Wait.
// NOTE(review): the packet-processing loop itself and several cleanup lines
// are missing from this excerpt; only the shutdown steps below are visible.
187 func (m *mux) loop() {
// On shutdown, forget every channel and hand each one to cleanup.
193 for _, ch := range m.chanList.dropAll() {
// Closing these channels signals end-of-stream to their readers; for example,
// SendRequest returns io.EOF once globalResponses is closed.
197 close(m.incomingChannels)
198 close(m.incomingRequests)
199 close(m.globalResponses)
// Wake every goroutine blocked in Wait.
205 m.errCond.Broadcast()
209 log.Println("loop exit", err)
213 // onePacket reads and processes one packet.
// Channel-open messages go to handleChannelOpen, and global request/reply
// messages go to handleGlobalPacket. Anything else is treated as a channel
// packet and routed to the channel whose local ID is in bytes 1-4.
// NOTE(review): several lines (error checks, debugMux guard, switch header)
// are missing from this excerpt.
214 func (m *mux) onePacket() error {
215 packet, err := m.conn.readPacket()
221 if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
222 log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
// The decode error is deliberately ignored: p is used only for this log line.
224 p, _ := decode(packet)
225 log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
231 return m.handleChannelOpen(packet)
232 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
233 return m.handleGlobalPacket(packet)
236 // Otherwise, assume a channel packet.
// NOTE(review): the guard for this return (presumably a minimum-length check
// before reading the ID) is missing from this excerpt.
238 return parseError(packet[0])
// Bytes 1-4 hold the recipient channel ID, big-endian.
240 id := binary.BigEndian.Uint32(packet[1:])
241 ch := m.chanList.getChan(id)
243 return fmt.Errorf("ssh: invalid channel %d", id)
246 return ch.handlePacket(packet)
// handleGlobalPacket decodes a global request or reply. Requests from the
// peer are queued on m.incomingRequests. Success/failure replies are forwarded
// on m.globalResponses to the SendRequest call waiting for them. Any other
// message type panics; onePacket routes only global message types here.
// NOTE(review): both sends block when the channel buffer is full, which
// stalls the read loop until a consumer drains it.
249 func (m *mux) handleGlobalPacket(packet []byte) error {
250 msg, err := decode(packet)
255 switch msg := msg.(type) {
256 case *globalRequestMsg:
257 m.incomingRequests <- &Request{
259 WantReply: msg.WantReply,
263 case *globalRequestSuccessMsg, *globalRequestFailureMsg:
264 m.globalResponses <- msg
266 panic(fmt.Sprintf("not a global message %#v", msg))
272 // handleChannelOpen schedules a channel to be Accept()ed.
// It parses the peer's channel-open message. If the advertised maximum packet
// size is out of range, it replies with a ConnectionFailed open-failure.
// Otherwise it creates an inbound channel from the peer's parameters and queues
// it on m.incomingChannels.
273 func (m *mux) handleChannelOpen(packet []byte) error {
274 var msg channelOpenMsg
275 if err := Unmarshal(packet, &msg); err != nil {
// Reject implausible peer-declared packet sizes: below minPacketLength or
// above 1<<31.
279 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
280 failMsg := channelOpenFailureMsg{
281 PeersID: msg.PeersID,
282 Reason: ConnectionFailed,
283 Message: "invalid request",
284 Language: "en_US.UTF-8",
286 return m.sendMessage(failMsg)
289 c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
290 c.remoteId = msg.PeersID
291 c.maxRemotePayload = msg.MaxPacketSize
292 c.remoteWin.add(msg.PeersWindow)
// This send blocks the read loop if the incomingChannels buffer is full.
293 m.incomingChannels <- c
// OpenChannel opens an outbound channel of type chanType with the given
// type-specific data. It returns the channel and the stream of requests the
// peer sends on it. NOTE(review): the error-handling lines between the call
// and the return are missing from this excerpt.
297 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
298 ch, err := m.openChannel(chanType, extra)
303 return ch, ch.incomingRequests, nil
// openChannel creates an outbound channel, sends a channel-open message that
// advertises its window and maximum incoming packet size, then waits on ch.msg
// for the peer's answer. A failure reply is returned as an *OpenChannelError;
// any other message type is an error.
// NOTE(review): several lines are missing from this excerpt, and the function
// continues past the end of the visible source.
306 func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
307 ch := m.newChannel(chanType, channelOutbound, extra)
309 ch.maxIncomingPayload = channelMaxPacket
311 open := channelOpenMsg{
313 PeersWindow: ch.myWindow,
314 MaxPacketSize: ch.maxIncomingPayload,
315 TypeSpecificData: extra,
318 if err := m.sendMessage(open); err != nil {
// Block until the read loop delivers the peer's confirm or failure message.
322 switch msg := (<-ch.msg).(type) {
323 case *channelOpenConfirmMsg:
325 case *channelOpenFailureMsg:
326 return nil, &OpenChannelError{msg.Reason, msg.Message}
328 return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)