This commit is contained in:
mr
2026-03-17 11:57:22 +01:00
parent edcfecd24b
commit 83285c2ab5
13 changed files with 221 additions and 245 deletions

View File

@@ -51,13 +51,16 @@ var protocolsPartners = map[protocol.ID]*common.ProtocolInfo{
}
type StreamService struct {
Key pp.ID
Host host.Host
Node common.DiscoveryPeer
Streams common.ProtocolStream
maxNodesConn int
Mu sync.RWMutex
Key pp.ID
Host host.Host
Node common.DiscoveryPeer
Streams common.ProtocolStream
maxNodesConn int
Mu sync.RWMutex
ResourceSearches *common.SearchTracker
// IsPeerKnown, when set, is called at stream open for every inbound protocol.
// Return false to reset the stream immediately. Left nil until wired by the node.
IsPeerKnown func(pid pp.ID) bool
}
func InitStream(ctx context.Context, h host.Host, key pp.ID, maxNode int, node common.DiscoveryPeer) (*StreamService, error) {
@@ -71,7 +74,7 @@ func InitStream(ctx context.Context, h host.Host, key pp.ID, maxNode int, node c
ResourceSearches: common.NewSearchTracker(),
}
for proto := range protocols {
service.Host.SetStreamHandler(proto, service.HandleResponse)
service.Host.SetStreamHandler(proto, service.gate(service.HandleResponse))
}
logger.Info().Msg("connect to partners...")
service.connectToPartners() // we set up a stream
@@ -79,6 +82,21 @@ func InitStream(ctx context.Context, h host.Host, key pp.ID, maxNode int, node c
return service, nil
}
// gate wraps a stream handler with IsPeerKnown validation.
// If the peer is unknown the entire connection is closed and the handler is not called.
// IsPeerKnown is read at stream-open time so it works even when set after InitStream.
func (s *StreamService) gate(h func(network.Stream)) func(network.Stream) {
return func(stream network.Stream) {
if s.IsPeerKnown != nil && !s.IsPeerKnown(stream.Conn().RemotePeer()) {
logger := oclib.GetLogger()
logger.Warn().Str("peer", stream.Conn().RemotePeer().String()).Msg("[stream] unknown peer, closing connection")
stream.Conn().Close()
return
}
h(stream)
}
}
func (s *StreamService) HandleResponse(stream network.Stream) {
s.Mu.Lock()
defer s.Mu.Unlock()
@@ -119,7 +137,7 @@ func (s *StreamService) connectToPartners() error {
go s.readLoop(s.Streams[proto][ss.Conn().RemotePeer()], ss.Conn().RemotePeer(), proto, info)
}
logger.Info().Msg("SetStreamHandler " + string(proto))
s.Host.SetStreamHandler(proto, f)
s.Host.SetStreamHandler(proto, s.gate(f))
}
return nil
}