Code refactoring for bpa operator
[icn.git] / cmd / bpa-operator / vendor / golang.org / x / crypto / ssh / handshake.go
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "crypto/rand"
9         "errors"
10         "fmt"
11         "io"
12         "log"
13         "net"
14         "sync"
15 )
16
17 // debugHandshake, if set, prints messages sent and received.  Key
18 // exchange messages are printed as if DH were used, so the debug
19 // messages are wrong when using ECDH.
20 const debugHandshake = false
21
22 // chanSize sets the amount of buffering SSH connections. This is
23 // primarily for testing: setting chanSize=0 uncovers deadlocks more
24 // quickly.
25 const chanSize = 16
26
27 // keyingTransport is a packet based transport that supports key
28 // changes. It need not be thread-safe. It should pass through
29 // msgNewKeys in both directions.
30 type keyingTransport interface {
31         packetConn
32
33         // prepareKeyChange sets up a key change. The key change for a
34         // direction will be effected if a msgNewKeys message is sent
35         // or received.
36         prepareKeyChange(*algorithms, *kexResult) error
37 }
38
39 // handshakeTransport implements rekeying on top of a keyingTransport
40 // and offers a thread-safe writePacket() interface.
41 type handshakeTransport struct {
42         conn   keyingTransport
43         config *Config
44
45         serverVersion []byte
46         clientVersion []byte
47
48         // hostKeys is non-empty if we are the server. In that case,
49         // it contains all host keys that can be used to sign the
50         // connection.
51         hostKeys []Signer
52
53         // hostKeyAlgorithms is non-empty if we are the client. In that case,
54         // we accept these key types from the server as host key.
55         hostKeyAlgorithms []string
56
57         // On read error, incoming is closed, and readError is set.
58         incoming  chan []byte
59         readError error
60
61         mu             sync.Mutex
62         writeError     error
63         sentInitPacket []byte
64         sentInitMsg    *kexInitMsg
65         pendingPackets [][]byte // Used when a key exchange is in progress.
66
67         // If the read loop wants to schedule a kex, it pings this
68         // channel, and the write loop will send out a kex
69         // message.
70         requestKex chan struct{}
71
72         // If the other side requests or confirms a kex, its kexInit
73         // packet is sent here for the write loop to find it.
74         startKex chan *pendingKex
75
76         // data for host key checking
77         hostKeyCallback HostKeyCallback
78         dialAddress     string
79         remoteAddr      net.Addr
80
81         // bannerCallback is non-empty if we are the client and it has been set in
82         // ClientConfig. In that case it is called during the user authentication
83         // dance to handle a custom server's message.
84         bannerCallback BannerCallback
85
86         // Algorithms agreed in the last key exchange.
87         algorithms *algorithms
88
89         readPacketsLeft uint32
90         readBytesLeft   int64
91
92         writePacketsLeft uint32
93         writeBytesLeft   int64
94
95         // The session ID or nil if first kex did not complete yet.
96         sessionID []byte
97 }
98
99 type pendingKex struct {
100         otherInit []byte
101         done      chan error
102 }
103
104 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
105         t := &handshakeTransport{
106                 conn:          conn,
107                 serverVersion: serverVersion,
108                 clientVersion: clientVersion,
109                 incoming:      make(chan []byte, chanSize),
110                 requestKex:    make(chan struct{}, 1),
111                 startKex:      make(chan *pendingKex, 1),
112
113                 config: config,
114         }
115         t.resetReadThresholds()
116         t.resetWriteThresholds()
117
118         // We always start with a mandatory key exchange.
119         t.requestKex <- struct{}{}
120         return t
121 }
122
123 func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
124         t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
125         t.dialAddress = dialAddr
126         t.remoteAddr = addr
127         t.hostKeyCallback = config.HostKeyCallback
128         t.bannerCallback = config.BannerCallback
129         if config.HostKeyAlgorithms != nil {
130                 t.hostKeyAlgorithms = config.HostKeyAlgorithms
131         } else {
132                 t.hostKeyAlgorithms = supportedHostKeyAlgos
133         }
134         go t.readLoop()
135         go t.kexLoop()
136         return t
137 }
138
139 func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
140         t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
141         t.hostKeys = config.hostKeys
142         go t.readLoop()
143         go t.kexLoop()
144         return t
145 }
146
147 func (t *handshakeTransport) getSessionID() []byte {
148         return t.sessionID
149 }
150
151 // waitSession waits for the session to be established. This should be
152 // the first thing to call after instantiating handshakeTransport.
153 func (t *handshakeTransport) waitSession() error {
154         p, err := t.readPacket()
155         if err != nil {
156                 return err
157         }
158         if p[0] != msgNewKeys {
159                 return fmt.Errorf("ssh: first packet should be msgNewKeys")
160         }
161
162         return nil
163 }
164
165 func (t *handshakeTransport) id() string {
166         if len(t.hostKeys) > 0 {
167                 return "server"
168         }
169         return "client"
170 }
171
172 func (t *handshakeTransport) printPacket(p []byte, write bool) {
173         action := "got"
174         if write {
175                 action = "sent"
176         }
177
178         if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
179                 log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
180         } else {
181                 msg, err := decode(p)
182                 log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
183         }
184 }
185
186 func (t *handshakeTransport) readPacket() ([]byte, error) {
187         p, ok := <-t.incoming
188         if !ok {
189                 return nil, t.readError
190         }
191         return p, nil
192 }
193
194 func (t *handshakeTransport) readLoop() {
195         first := true
196         for {
197                 p, err := t.readOnePacket(first)
198                 first = false
199                 if err != nil {
200                         t.readError = err
201                         close(t.incoming)
202                         break
203                 }
204                 if p[0] == msgIgnore || p[0] == msgDebug {
205                         continue
206                 }
207                 t.incoming <- p
208         }
209
210         // Stop writers too.
211         t.recordWriteError(t.readError)
212
213         // Unblock the writer should it wait for this.
214         close(t.startKex)
215
216         // Don't close t.requestKex; it's also written to from writePacket.
217 }
218
219 func (t *handshakeTransport) pushPacket(p []byte) error {
220         if debugHandshake {
221                 t.printPacket(p, true)
222         }
223         return t.conn.writePacket(p)
224 }
225
226 func (t *handshakeTransport) getWriteError() error {
227         t.mu.Lock()
228         defer t.mu.Unlock()
229         return t.writeError
230 }
231
232 func (t *handshakeTransport) recordWriteError(err error) {
233         t.mu.Lock()
234         defer t.mu.Unlock()
235         if t.writeError == nil && err != nil {
236                 t.writeError = err
237         }
238 }
239
240 func (t *handshakeTransport) requestKeyExchange() {
241         select {
242         case t.requestKex <- struct{}{}:
243         default:
244                 // something already requested a kex, so do nothing.
245         }
246 }
247
248 func (t *handshakeTransport) resetWriteThresholds() {
249         t.writePacketsLeft = packetRekeyThreshold
250         if t.config.RekeyThreshold > 0 {
251                 t.writeBytesLeft = int64(t.config.RekeyThreshold)
252         } else if t.algorithms != nil {
253                 t.writeBytesLeft = t.algorithms.w.rekeyBytes()
254         } else {
255                 t.writeBytesLeft = 1 << 30
256         }
257 }
258
259 func (t *handshakeTransport) kexLoop() {
260
261 write:
262         for t.getWriteError() == nil {
263                 var request *pendingKex
264                 var sent bool
265
266                 for request == nil || !sent {
267                         var ok bool
268                         select {
269                         case request, ok = <-t.startKex:
270                                 if !ok {
271                                         break write
272                                 }
273                         case <-t.requestKex:
274                                 break
275                         }
276
277                         if !sent {
278                                 if err := t.sendKexInit(); err != nil {
279                                         t.recordWriteError(err)
280                                         break
281                                 }
282                                 sent = true
283                         }
284                 }
285
286                 if err := t.getWriteError(); err != nil {
287                         if request != nil {
288                                 request.done <- err
289                         }
290                         break
291                 }
292
293                 // We're not servicing t.requestKex, but that is OK:
294                 // we never block on sending to t.requestKex.
295
296                 // We're not servicing t.startKex, but the remote end
297                 // has just sent us a kexInitMsg, so it can't send
298                 // another key change request, until we close the done
299                 // channel on the pendingKex request.
300
301                 err := t.enterKeyExchange(request.otherInit)
302
303                 t.mu.Lock()
304                 t.writeError = err
305                 t.sentInitPacket = nil
306                 t.sentInitMsg = nil
307
308                 t.resetWriteThresholds()
309
310                 // we have completed the key exchange. Since the
311                 // reader is still blocked, it is safe to clear out
312                 // the requestKex channel. This avoids the situation
313                 // where: 1) we consumed our own request for the
314                 // initial kex, and 2) the kex from the remote side
315                 // caused another send on the requestKex channel,
316         clear:
317                 for {
318                         select {
319                         case <-t.requestKex:
320                                 //
321                         default:
322                                 break clear
323                         }
324                 }
325
326                 request.done <- t.writeError
327
328                 // kex finished. Push packets that we received while
329                 // the kex was in progress. Don't look at t.startKex
330                 // and don't increment writtenSinceKex: if we trigger
331                 // another kex while we are still busy with the last
332                 // one, things will become very confusing.
333                 for _, p := range t.pendingPackets {
334                         t.writeError = t.pushPacket(p)
335                         if t.writeError != nil {
336                                 break
337                         }
338                 }
339                 t.pendingPackets = t.pendingPackets[:0]
340                 t.mu.Unlock()
341         }
342
343         // drain startKex channel. We don't service t.requestKex
344         // because nobody does blocking sends there.
345         go func() {
346                 for init := range t.startKex {
347                         init.done <- t.writeError
348                 }
349         }()
350
351         // Unblock reader.
352         t.conn.Close()
353 }
354
355 // The protocol uses uint32 for packet counters, so we can't let them
356 // reach 1<<32.  We will actually read and write more packets than
357 // this, though: the other side may send more packets, and after we
358 // hit this limit on writing we will send a few more packets for the
359 // key exchange itself.
360 const packetRekeyThreshold = (1 << 31)
361
362 func (t *handshakeTransport) resetReadThresholds() {
363         t.readPacketsLeft = packetRekeyThreshold
364         if t.config.RekeyThreshold > 0 {
365                 t.readBytesLeft = int64(t.config.RekeyThreshold)
366         } else if t.algorithms != nil {
367                 t.readBytesLeft = t.algorithms.r.rekeyBytes()
368         } else {
369                 t.readBytesLeft = 1 << 30
370         }
371 }
372
373 func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
374         p, err := t.conn.readPacket()
375         if err != nil {
376                 return nil, err
377         }
378
379         if t.readPacketsLeft > 0 {
380                 t.readPacketsLeft--
381         } else {
382                 t.requestKeyExchange()
383         }
384
385         if t.readBytesLeft > 0 {
386                 t.readBytesLeft -= int64(len(p))
387         } else {
388                 t.requestKeyExchange()
389         }
390
391         if debugHandshake {
392                 t.printPacket(p, false)
393         }
394
395         if first && p[0] != msgKexInit {
396                 return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
397         }
398
399         if p[0] != msgKexInit {
400                 return p, nil
401         }
402
403         firstKex := t.sessionID == nil
404
405         kex := pendingKex{
406                 done:      make(chan error, 1),
407                 otherInit: p,
408         }
409         t.startKex <- &kex
410         err = <-kex.done
411
412         if debugHandshake {
413                 log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
414         }
415
416         if err != nil {
417                 return nil, err
418         }
419
420         t.resetReadThresholds()
421
422         // By default, a key exchange is hidden from higher layers by
423         // translating it into msgIgnore.
424         successPacket := []byte{msgIgnore}
425         if firstKex {
426                 // sendKexInit() for the first kex waits for
427                 // msgNewKeys so the authentication process is
428                 // guaranteed to happen over an encrypted transport.
429                 successPacket = []byte{msgNewKeys}
430         }
431
432         return successPacket, nil
433 }
434
435 // sendKexInit sends a key change message.
436 func (t *handshakeTransport) sendKexInit() error {
437         t.mu.Lock()
438         defer t.mu.Unlock()
439         if t.sentInitMsg != nil {
440                 // kexInits may be sent either in response to the other side,
441                 // or because our side wants to initiate a key change, so we
442                 // may have already sent a kexInit. In that case, don't send a
443                 // second kexInit.
444                 return nil
445         }
446
447         msg := &kexInitMsg{
448                 KexAlgos:                t.config.KeyExchanges,
449                 CiphersClientServer:     t.config.Ciphers,
450                 CiphersServerClient:     t.config.Ciphers,
451                 MACsClientServer:        t.config.MACs,
452                 MACsServerClient:        t.config.MACs,
453                 CompressionClientServer: supportedCompressions,
454                 CompressionServerClient: supportedCompressions,
455         }
456         io.ReadFull(rand.Reader, msg.Cookie[:])
457
458         if len(t.hostKeys) > 0 {
459                 for _, k := range t.hostKeys {
460                         msg.ServerHostKeyAlgos = append(
461                                 msg.ServerHostKeyAlgos, k.PublicKey().Type())
462                 }
463         } else {
464                 msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
465         }
466         packet := Marshal(msg)
467
468         // writePacket destroys the contents, so save a copy.
469         packetCopy := make([]byte, len(packet))
470         copy(packetCopy, packet)
471
472         if err := t.pushPacket(packetCopy); err != nil {
473                 return err
474         }
475
476         t.sentInitMsg = msg
477         t.sentInitPacket = packet
478
479         return nil
480 }
481
482 func (t *handshakeTransport) writePacket(p []byte) error {
483         switch p[0] {
484         case msgKexInit:
485                 return errors.New("ssh: only handshakeTransport can send kexInit")
486         case msgNewKeys:
487                 return errors.New("ssh: only handshakeTransport can send newKeys")
488         }
489
490         t.mu.Lock()
491         defer t.mu.Unlock()
492         if t.writeError != nil {
493                 return t.writeError
494         }
495
496         if t.sentInitMsg != nil {
497                 // Copy the packet so the writer can reuse the buffer.
498                 cp := make([]byte, len(p))
499                 copy(cp, p)
500                 t.pendingPackets = append(t.pendingPackets, cp)
501                 return nil
502         }
503
504         if t.writeBytesLeft > 0 {
505                 t.writeBytesLeft -= int64(len(p))
506         } else {
507                 t.requestKeyExchange()
508         }
509
510         if t.writePacketsLeft > 0 {
511                 t.writePacketsLeft--
512         } else {
513                 t.requestKeyExchange()
514         }
515
516         if err := t.pushPacket(p); err != nil {
517                 t.writeError = err
518         }
519
520         return nil
521 }
522
523 func (t *handshakeTransport) Close() error {
524         return t.conn.Close()
525 }
526
527 func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
528         if debugHandshake {
529                 log.Printf("%s entered key exchange", t.id())
530         }
531
532         otherInit := &kexInitMsg{}
533         if err := Unmarshal(otherInitPacket, otherInit); err != nil {
534                 return err
535         }
536
537         magics := handshakeMagics{
538                 clientVersion: t.clientVersion,
539                 serverVersion: t.serverVersion,
540                 clientKexInit: otherInitPacket,
541                 serverKexInit: t.sentInitPacket,
542         }
543
544         clientInit := otherInit
545         serverInit := t.sentInitMsg
546         if len(t.hostKeys) == 0 {
547                 clientInit, serverInit = serverInit, clientInit
548
549                 magics.clientKexInit = t.sentInitPacket
550                 magics.serverKexInit = otherInitPacket
551         }
552
553         var err error
554         t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
555         if err != nil {
556                 return err
557         }
558
559         // We don't send FirstKexFollows, but we handle receiving it.
560         //
561         // RFC 4253 section 7 defines the kex and the agreement method for
562         // first_kex_packet_follows. It states that the guessed packet
563         // should be ignored if the "kex algorithm and/or the host
564         // key algorithm is guessed wrong (server and client have
565         // different preferred algorithm), or if any of the other
566         // algorithms cannot be agreed upon". The other algorithms have
567         // already been checked above so the kex algorithm and host key
568         // algorithm are checked here.
569         if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
570                 // other side sent a kex message for the wrong algorithm,
571                 // which we have to ignore.
572                 if _, err := t.conn.readPacket(); err != nil {
573                         return err
574                 }
575         }
576
577         kex, ok := kexAlgoMap[t.algorithms.kex]
578         if !ok {
579                 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
580         }
581
582         var result *kexResult
583         if len(t.hostKeys) > 0 {
584                 result, err = t.server(kex, t.algorithms, &magics)
585         } else {
586                 result, err = t.client(kex, t.algorithms, &magics)
587         }
588
589         if err != nil {
590                 return err
591         }
592
593         if t.sessionID == nil {
594                 t.sessionID = result.H
595         }
596         result.SessionID = t.sessionID
597
598         if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
599                 return err
600         }
601         if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
602                 return err
603         }
604         if packet, err := t.conn.readPacket(); err != nil {
605                 return err
606         } else if packet[0] != msgNewKeys {
607                 return unexpectedMessageError(msgNewKeys, packet[0])
608         }
609
610         return nil
611 }
612
613 func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
614         var hostKey Signer
615         for _, k := range t.hostKeys {
616                 if algs.hostKey == k.PublicKey().Type() {
617                         hostKey = k
618                 }
619         }
620
621         r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
622         return r, err
623 }
624
625 func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
626         result, err := kex.Client(t.conn, t.config.Rand, magics)
627         if err != nil {
628                 return nil, err
629         }
630
631         hostKey, err := ParsePublicKey(result.HostKey)
632         if err != nil {
633                 return nil, err
634         }
635
636         if err := verifyHostKeySignature(hostKey, result); err != nil {
637                 return nil, err
638         }
639
640         err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
641         if err != nil {
642                 return nil, err
643         }
644
645         return result, nil
646 }