@@ -72,17 +72,19 @@ func NewCommInstanceWithServer(port int, idMapper identity.Mapper, peerIdentity
72
72
var ll net.Listener
73
73
var s * grpc.Server
74
74
var secOpt grpc.DialOption
75
+ var certHash []byte
75
76
76
77
if len (dialOpts ) == 0 {
77
78
dialOpts = []grpc.DialOption {grpc .WithTimeout (dialTimeout )}
78
79
}
79
80
80
81
if port > 0 {
81
- s , ll , secOpt = createGRPCLayer (port )
82
+ s , ll , secOpt , certHash = createGRPCLayer (port )
82
83
dialOpts = append (dialOpts , secOpt )
83
84
}
84
85
85
86
commInst := & commImpl {
87
+ selfCertHash : certHash ,
86
88
PKIID : idMapper .GetPKIidOfCert (peerIdentity ),
87
89
idMapper : idMapper ,
88
90
logger : util .GetLogger (util .LOGGING_COMM_MODULE , fmt .Sprintf ("%d" , port )),
@@ -117,16 +119,28 @@ func NewCommInstanceWithServer(port int, idMapper identity.Mapper, peerIdentity
117
119
}
118
120
119
121
// NewCommInstance creates a new comm instance that binds itself to the given gRPC server
120
- func NewCommInstance (s * grpc.Server , idStore identity.Mapper , peerIdentity api.PeerIdentityType , dialOpts ... grpc.DialOption ) (Comm , error ) {
122
+ func NewCommInstance (s * grpc.Server , cert * tls. Certificate , idStore identity.Mapper , peerIdentity api.PeerIdentityType , dialOpts ... grpc.DialOption ) (Comm , error ) {
121
123
commInst , err := NewCommInstanceWithServer (- 1 , idStore , peerIdentity , dialOpts ... )
122
124
if err != nil {
123
125
return nil , err
124
126
}
127
+
128
+ if cert != nil {
129
+ inst := commInst .(* commImpl )
130
+ if len (cert .Certificate ) == 0 {
131
+ inst .logger .Panic ("Certificate supplied but certificate chain is empty" )
132
+ } else {
133
+ inst .selfCertHash = certHashFromRawCert (cert .Certificate [0 ])
134
+ }
135
+ }
136
+
125
137
proto .RegisterGossipServer (s , commInst .(* commImpl ))
138
+
126
139
return commInst , nil
127
140
}
128
141
129
142
type commImpl struct {
143
+ selfCertHash []byte
130
144
peerIdentity api.PeerIdentityType
131
145
idMapper identity.Mapper
132
146
logger * util.Logger
@@ -373,13 +387,16 @@ func extractRemoteAddress(stream stream) string {
373
387
func (c * commImpl ) authenticateRemotePeer (stream stream ) (common.PKIidType , error ) {
374
388
ctx := stream .Context ()
375
389
remoteAddress := extractRemoteAddress (stream )
376
- tlsUnique := ExtractTLSUnique (ctx )
390
+ remoteCertHash := extractCertificateHashFromContext (ctx )
377
391
var sig []byte
378
392
var err error
379
- if tlsUnique != nil {
380
- sig , err = c .idMapper .Sign (tlsUnique )
393
+
394
+ // If TLS is detected, sign the hash of our cert to bind our TLS cert
395
+ // to the gRPC session
396
+ if remoteCertHash != nil && c .selfCertHash != nil {
397
+ sig , err = c .idMapper .Sign (c .selfCertHash )
381
398
if err != nil {
382
- c .logger .Error ("Failed signing TLS-Unique :" , err )
399
+ c .logger .Error ("Failed signing self certificate hash :" , err )
383
400
return nil , err
384
401
}
385
402
}
@@ -414,8 +431,9 @@ func (c *commImpl) authenticateRemotePeer(stream stream) (common.PKIidType, erro
414
431
return nil , err
415
432
}
416
433
417
- if tlsUnique != nil {
418
- err = c .idMapper .Verify (receivedMsg .PkiID , receivedMsg .Sig , tlsUnique )
434
+ // if TLS is detected, verify remote peer
435
+ if remoteCertHash != nil && c .selfCertHash != nil {
436
+ err = c .idMapper .Verify (receivedMsg .PkiID , receivedMsg .Sig , remoteCertHash )
419
437
if err != nil {
420
438
c .logger .Error ("Failed verifying signature from" , remoteAddress , ":" , err )
421
439
return nil , err
@@ -424,7 +442,6 @@ func (c *commImpl) authenticateRemotePeer(stream stream) (common.PKIidType, erro
424
442
425
443
c .logger .Debug ("Authenticated" , remoteAddress )
426
444
return receivedMsg .PkiID , nil
427
-
428
445
}
429
446
430
447
func (c * commImpl ) GossipStream (stream proto.Gossip_GossipStreamServer ) error {
@@ -518,7 +535,8 @@ type stream interface {
518
535
grpc.Stream
519
536
}
520
537
521
- func createGRPCLayer (port int ) (* grpc.Server , net.Listener , grpc.DialOption ) {
538
+ func createGRPCLayer (port int ) (* grpc.Server , net.Listener , grpc.DialOption , []byte ) {
539
+ var returnedCertHash []byte
522
540
var s * grpc.Server
523
541
var ll net.Listener
524
542
var err error
@@ -533,10 +551,25 @@ func createGRPCLayer(port int) (*grpc.Server, net.Listener, grpc.DialOption) {
533
551
534
552
err = generateCertificates (keyFileName , certFileName )
535
553
if err == nil {
536
- var creds credentials.TransportCredentials
537
- creds , err = credentials .NewServerTLSFromFile (certFileName , keyFileName )
538
- serverOpts = append (serverOpts , grpc .Creds (creds ))
554
+ cert , err := tls .LoadX509KeyPair (certFileName , keyFileName )
555
+ if err != nil {
556
+ panic (err )
557
+ }
558
+
559
+ if len (cert .Certificate ) == 0 {
560
+ panic (fmt .Errorf ("Certificate chain is nil" ))
561
+ }
562
+
563
+ returnedCertHash = certHashFromRawCert (cert .Certificate [0 ])
564
+
565
+ tlsConf := & tls.Config {
566
+ Certificates : []tls.Certificate {cert },
567
+ ClientAuth : tls .RequestClientCert ,
568
+ InsecureSkipVerify : true ,
569
+ }
570
+ serverOpts = append (serverOpts , grpc .Creds (credentials .NewTLS (tlsConf )))
539
571
ta := credentials .NewTLS (& tls.Config {
572
+ Certificates : []tls.Certificate {cert },
540
573
InsecureSkipVerify : true ,
541
574
})
542
575
dialOpts = grpc .WithTransportCredentials (& authCreds {tlsCreds : ta })
@@ -551,5 +584,5 @@ func createGRPCLayer(port int) (*grpc.Server, net.Listener, grpc.DialOption) {
551
584
}
552
585
553
586
s = grpc .NewServer (serverOpts ... )
554
- return s , ll , dialOpts
587
+ return s , ll , dialOpts , returnedCertHash
555
588
}
0 commit comments