diff --git a/weed/cluster/cluster.go b/weed/cluster/cluster.go index 6c24df44c..ad6e6b879 100644 --- a/weed/cluster/cluster.go +++ b/weed/cluster/cluster.go @@ -46,8 +46,6 @@ func NewCluster() *Cluster { } func (cluster *Cluster) getFilers(filerGroup FilerGroup, createIfNotFound bool) *Filers { - cluster.filersLock.Lock() - defer cluster.filersLock.Unlock() filers, found := cluster.filerGroup2filers[filerGroup] if !found && createIfNotFound { filers = &Filers{ @@ -63,6 +61,8 @@ func (cluster *Cluster) AddClusterNode(ns, nodeType string, address pb.ServerAdd filerGroup := FilerGroup(ns) switch nodeType { case FilerType: + cluster.filersLock.Lock() + defer cluster.filersLock.Unlock() filers := cluster.getFilers(filerGroup, true) if existingNode, found := filers.filers[address]; found { existingNode.counter++ @@ -115,6 +115,8 @@ func (cluster *Cluster) RemoveClusterNode(ns string, nodeType string, address pb filerGroup := FilerGroup(ns) switch nodeType { case FilerType: + cluster.filersLock.Lock() + defer cluster.filersLock.Unlock() filers := cluster.getFilers(filerGroup, false) if filers == nil { return nil @@ -165,12 +167,12 @@ func (cluster *Cluster) RemoveClusterNode(ns string, nodeType string, address pb func (cluster *Cluster) ListClusterNode(filerGroup FilerGroup, nodeType string) (nodes []*ClusterNode) { switch nodeType { case FilerType: + cluster.filersLock.RLock() + defer cluster.filersLock.RUnlock() filers := cluster.getFilers(filerGroup, false) if filers == nil { return } - cluster.filersLock.RLock() - defer cluster.filersLock.RUnlock() for _, node := range filers.filers { nodes = append(nodes, node) } diff --git a/weed/cluster/cluster_test.go b/weed/cluster/cluster_test.go index ccaccf6f7..1187642de 100644 --- a/weed/cluster/cluster_test.go +++ b/weed/cluster/cluster_test.go @@ -3,6 +3,8 @@ package cluster import ( "github.com/chrislusf/seaweedfs/weed/pb" "github.com/stretchr/testify/assert" + "strconv" + "sync" "testing" ) @@ -45,3 +47,35 @@ func TestClusterAddRemoveNodes(t *testing.T) { c.RemoveClusterNode("", "filer", pb.ServerAddress("111:1")) } + +func TestConcurrentAddRemoveNodes(t *testing.T) { + c := NewCluster() + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + address := strconv.Itoa(i) + c.AddClusterNode("", "filer", pb.ServerAddress(address), "23.45") + }(i) + } + wg.Wait() + + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + address := strconv.Itoa(i) + node := c.RemoveClusterNode("", "filer", pb.ServerAddress(address)) + + if len(node) == 0 { + t.Errorf("TestConcurrentAddRemoveNodes: node[%s] not found", address) + return + } else if node[0].ClusterNodeUpdate.Address != address { + t.Errorf("TestConcurrentAddRemoveNodes: expect:%s, actual:%s", address, node[0].ClusterNodeUpdate.Address) + return + } + }(i) + } + wg.Wait() +}