@@ -18,6 +18,8 @@ package comm
18
18
19
19
import (
20
20
"bytes"
21
+ "crypto/hmac"
22
+ "crypto/sha256"
21
23
"crypto/tls"
22
24
"fmt"
23
25
"math/rand"
@@ -49,7 +51,10 @@ func acceptAll(msg interface{}) bool {
49
51
return true
50
52
}
51
53
52
- var naiveSec = & naiveSecProvider {}
54
+ var (
55
+ naiveSec = & naiveSecProvider {}
56
+ hmacKey = []byte {0 , 0 , 0 }
57
+ )
53
58
54
59
type naiveSecProvider struct {
55
60
}
@@ -72,15 +77,19 @@ func (*naiveSecProvider) VerifyBlock(chainID common.ChainID, signedBlock []byte)
72
77
// Sign signs msg with this peer's signing key and outputs
73
78
// the signature if no error occurred.
74
79
func (* naiveSecProvider ) Sign (msg []byte ) ([]byte , error ) {
75
- return msg , nil
80
+ mac := hmac .New (sha256 .New , hmacKey )
81
+ mac .Write (msg )
82
+ return mac .Sum (nil ), nil
76
83
}
77
84
78
85
// Verify checks that signature is a valid signature of message under a peer's verification key.
79
86
// If the verification succeeded, Verify returns nil meaning no error occurred.
80
87
// If peerCert is nil, then the signature is verified against this peer's verification key.
81
88
func (* naiveSecProvider ) Verify (peerIdentity api.PeerIdentityType , signature , message []byte ) error {
82
- equal := bytes .Equal (signature , message )
83
- if ! equal {
89
+ mac := hmac .New (sha256 .New , hmacKey )
90
+ mac .Write (message )
91
+ expected := mac .Sum (nil )
92
+ if ! bytes .Equal (signature , expected ) {
84
93
return fmt .Errorf ("Wrong certificate:%v, %v" , signature , message )
85
94
}
86
95
return nil
@@ -98,16 +107,22 @@ func newCommInstance(port int, sec api.MessageCryptoService) (Comm, error) {
98
107
return inst , err
99
108
}
100
109
101
- func handshaker (endpoint string , comm Comm , t * testing.T , sigMutator func ([]byte ) []byte , pkiIDmutator func ([]byte ) []byte ) <- chan proto.ReceivedMessage {
110
+ func handshaker (endpoint string , comm Comm , t * testing.T , sigMutator func ([]byte ) []byte , pkiIDmutator func ([]byte ) []byte , mutualTLS bool ) <- chan proto.ReceivedMessage {
102
111
c := & commImpl {}
103
112
err := generateCertificates ("key.pem" , "cert.pem" )
104
113
defer os .Remove ("cert.pem" )
105
114
defer os .Remove ("key.pem" )
106
115
cert , err := tls .LoadX509KeyPair ("cert.pem" , "key.pem" )
107
- ta := credentials . NewTLS ( & tls.Config {
116
+ tlsCfg := & tls.Config {
108
117
InsecureSkipVerify : true ,
109
- Certificates : []tls.Certificate {cert },
110
- })
118
+ }
119
+
120
+ if mutualTLS {
121
+ tlsCfg .Certificates = []tls.Certificate {cert }
122
+ }
123
+
124
+ ta := credentials .NewTLS (tlsCfg )
125
+
111
126
acceptChan := comm .Accept (acceptAll )
112
127
conn , err := grpc .Dial ("localhost:9611" , grpc .WithTransportCredentials (& authCreds {tlsCreds : ta }), grpc .WithBlock (), grpc .WithTimeout (time .Second ))
113
128
assert .NoError (t , err , "%v" , err )
@@ -119,16 +134,25 @@ func handshaker(endpoint string, comm Comm, t *testing.T, sigMutator func([]byte
119
134
assert .NoError (t , err , "%v" , err )
120
135
if err != nil {
121
136
return nil
137
+ } // cert.Certificate[0]
138
+
139
+ var clientCertHash []byte
140
+ if mutualTLS {
141
+ clientCertHash = certHashFromRawCert (tlsCfg .Certificates [0 ].Certificate [0 ])
122
142
}
123
- clientCertHash := certHashFromRawCert (cert .Certificate [0 ])
124
143
125
144
pkiID := common .PKIidType (endpoint )
126
145
if pkiIDmutator != nil {
127
146
pkiID = common .PKIidType (pkiIDmutator ([]byte (endpoint )))
128
147
}
129
148
assert .NoError (t , err , "%v" , err )
130
149
msg := c .createConnectionMsg (pkiID , clientCertHash , []byte (endpoint ), func (msg []byte ) ([]byte , error ) {
131
- return msg , nil
150
+ if ! mutualTLS {
151
+ return msg , nil
152
+ }
153
+ mac := hmac .New (sha256 .New , hmacKey )
154
+ mac .Write (msg )
155
+ return mac .Sum (nil ), nil
132
156
})
133
157
134
158
if sigMutator != nil {
@@ -143,9 +167,14 @@ func handshaker(endpoint string, comm Comm, t *testing.T, sigMutator func([]byte
143
167
if sigMutator == nil {
144
168
hash := extractCertificateHashFromContext (stream .Context ())
145
169
expectedMsg := c .createConnectionMsg (common .PKIidType ("localhost:9611" ), hash , []byte ("localhost:9611" ), func (msg []byte ) ([]byte , error ) {
146
- return msg , nil
170
+ mac := hmac .New (sha256 .New , hmacKey )
171
+ mac .Write (msg )
172
+ return mac .Sum (nil ), nil
147
173
})
148
- assert .Equal (t , expectedMsg .Envelope .Signature , msg .Envelope .Signature )
174
+ if mutualTLS {
175
+ assert .Equal (t , expectedMsg .Envelope .Signature , msg .Envelope .Signature )
176
+ }
177
+
149
178
}
150
179
assert .Equal (t , []byte ("localhost:9611" ), msg .GetConn ().PkiId )
151
180
msg2Send := createGossipMsg ()
@@ -177,7 +206,7 @@ func TestHandshake(t *testing.T) {
177
206
comm , _ := newCommInstance (9611 , naiveSec )
178
207
defer comm .Stop ()
179
208
180
- acceptChan := handshaker ("localhost:9610" , comm , t , nil , nil )
209
+ acceptChan := handshaker ("localhost:9610" , comm , t , nil , nil , true )
181
210
time .Sleep (2 * time .Second )
182
211
assert .Equal (t , 1 , len (acceptChan ))
183
212
msg := <- acceptChan
@@ -186,7 +215,8 @@ func TestHandshake(t *testing.T) {
186
215
assert .Equal (t , api .PeerIdentityType ("localhost:9610" ), msg .GetConnectionInfo ().Identity )
187
216
assert .NotNil (t , msg .GetConnectionInfo ().Auth )
188
217
assert .True (t , msg .GetConnectionInfo ().IsAuthenticated ())
189
- assert .Equal (t , msg .GetConnectionInfo ().Auth .Signature , msg .GetConnectionInfo ().Auth .SignedData )
218
+ sig , _ := (& naiveSecProvider {}).Sign (msg .GetConnectionInfo ().Auth .SignedData )
219
+ assert .Equal (t , sig , msg .GetConnectionInfo ().Auth .Signature )
190
220
// negative path, nothing should be read from the channel because the signature is wrong
191
221
mutateSig := func (b []byte ) []byte {
192
222
if b [0 ] == 0 {
@@ -196,17 +226,36 @@ func TestHandshake(t *testing.T) {
196
226
}
197
227
return b
198
228
}
199
- acceptChan = handshaker ("localhost:9612" , comm , t , mutateSig , nil )
229
+ acceptChan = handshaker ("localhost:9612" , comm , t , mutateSig , nil , true )
200
230
time .Sleep (time .Second )
201
231
assert .Equal (t , 0 , len (acceptChan ))
202
232
203
233
// negative path, nothing should be read from the channel because the PKIid doesn't match the identity
204
234
mutatePKIID := func (b []byte ) []byte {
205
235
return []byte ("localhost:9650" )
206
236
}
207
- acceptChan = handshaker ("localhost:9613" , comm , t , nil , mutatePKIID )
237
+ acceptChan = handshaker ("localhost:9613" , comm , t , nil , mutatePKIID , true )
208
238
time .Sleep (time .Second )
209
239
assert .Equal (t , 0 , len (acceptChan ))
240
+
241
+ // Now we test for a handshake without mutual TLS
242
+ // The first time should fail
243
+ acceptChan = handshaker ("localhost:9614" , comm , t , nil , nil , false )
244
+ select {
245
+ case <- acceptChan :
246
+ assert .Fail (t , "Should not have successfully authenticated to remote peer" )
247
+ case <- time .After (time .Second ):
248
+ }
249
+
250
+ // And the second time should succeed
251
+ comm .(* commImpl ).skipHandshake = true
252
+ acceptChan = handshaker ("localhost:9615" , comm , t , nil , nil , false )
253
+ select {
254
+ case <- acceptChan :
255
+ case <- time .After (time .Second * 10 ):
256
+ assert .Fail (t , "skipHandshake flag should have authorized the authentication" )
257
+ }
258
+
210
259
}
211
260
212
261
func TestBasic (t * testing.T ) {
0 commit comments