func()

in vermeer/apps/worker/compute_bl.go [137:297]


// RunSuperStep runs one compute superstep for the task identified by taskID:
// it advances the step counters, optionally unmarshals master-supplied
// computeValues into the compute context, runs per-vertex Compute over this
// worker's vertex range using ctx.Parallel goroutines, streams the computed
// vertex values (and scatter values) to every peer worker in batches, and
// finally signals step end to all peers and to this worker itself.
// Any panic — in the method body or in a worker goroutine — is recovered,
// logged, and reported via SetStatusError.
func (cb *ComputeBl) RunSuperStep(taskID int32, computeValues map[string][]byte) {
	defer func() {
		if r := recover(); r != nil {
			cb.SetStatusError(taskID, fmt.Sprintf("RunSuperStep panic recover panic:%v, stack message: %s", r,
				common.GetCurrentGoroutineStack()))
			logrus.Errorf("RunSuperStep panic recover taskID:%v, panic:%v, stack message: %s", taskID, r,
				common.GetCurrentGoroutineStack())
		}
	}()
	computeTask := ComputeTaskMgr.GetTask(taskID)
	if !cb.CheckAction(computeTask) {
		return
	}
	computeTask.Step += 1
	ctx := computeTask.ComputeWorker.Context()
	ctx.Step += 1
	if computeValues != nil {
		ctx.UnmarshalValues(computeValues)
	}
	computeTask.Task.SetState(structure.TaskStateStepDoing)

	logrus.Infof("RunSuperStep start step: %d", computeTask.Step)
	computeTask.ComputeWorker.BeforeStep()
	computeTask.StepWg.Done()

	// Collect peer clients for every worker in the task except this one.
	workerCount := len(computeTask.Task.Workers)
	peers := make([]*PeerClient, 0, workerCount-1)
	for _, wn := range computeTask.Task.Workers {
		if wn.Name == ServiceWorker.WorkerName {
			continue
		}
		peers = append(peers, PeerMgr.GetPeer(wn.Name))
	}

	// Each of the ctx.Parallel goroutines sends one end-marker scatter to
	// every peer, so the receive side expects parallel*len(peers) signals.
	*computeTask.Parallel = int32(ctx.Parallel) * int32(len(peers))
	computeTask.RecvWg.Add(int32(ctx.Parallel) * int32(len(peers)))

	parallel := ctx.Parallel
	// Vertex-range size per goroutine, rounded up so the union covers all
	// local vertices (the last range is clamped below).
	partCnt := int(ctx.GraphData.VertexCount)/parallel + 1
	wg := sync.WaitGroup{}
	for i := 0; i < parallel; i++ {
		wg.Add(1)
		go func(pId int) {
			defer func() {
				if r := recover(); r != nil {
					cb.SetStatusError(taskID, fmt.Sprintf("RunSuperStep panic recover panic:%v, stack message: %s",
						r, common.GetCurrentGoroutineStack()))
					logrus.Errorf("RunSuperStep panic recover taskID:%v, pId:%v panic:%v, stack message: %s",
						taskID, pId, r, common.GetCurrentGoroutineStack())
				}
			}()
			defer wg.Done()
			// Always send the end-marker (empty payload, end=true) to every
			// peer, even on early return, so receivers are not left waiting.
			defer func() {
				for _, peer := range peers {
					peer.ScatterHandler.SendScatter(
						taskID,
						ctx.Step,
						0,
						true,
						0,
						[]byte{})
				}
			}()
			// This goroutine's half-open vertex range [bIdx, eIdx), clamped
			// to the end of the local vertex span.
			bIdx := uint32(partCnt*pId) + ctx.GraphData.VertIDStart
			eIdx := bIdx + uint32(partCnt)
			if eIdx > ctx.GraphData.VertIDStart+ctx.GraphData.VertexCount {
				eIdx = ctx.GraphData.VertIDStart + ctx.GraphData.VertexCount
			}
			vOffSet := serialize.SUint32(bIdx)
			for j := bIdx; j < eIdx; j++ {
				// Check for cancellation only every 10000 vertices to keep
				// the overhead of CheckAction low.
				if j%10000 == 0 && !cb.CheckAction(computeTask) {
					return
				}
				computeTask.ComputeWorker.Compute(j, pId)
			}

			// Stream the computed vertex values to the peers. Each buffer
			// payload starts with the vertex offset its values begin at.
			if len(peers) > 0 {
				_ = ctx.SendBuffers[pId].Marshal(&vOffSet)
				for j := bIdx; j < eIdx; j++ {
					// Same sampling of CheckAction as above.
					if j%10000 == 0 && !cb.CheckAction(computeTask) {
						return
					}
					err := ctx.SendBuffers[pId].Marshal(computeTask.ComputeWorker.VertexValue(j))
					if err != nil {
						logrus.Errorf("sendbuffer marshal error:%v", err)
						cb.SetStatusError(taskID, fmt.Sprintf("sendbuffer marshal error:%v", err))
					}
					// Flush a full buffer to every peer, then restart it with
					// the offset of the next vertex.
					if ctx.SendBuffers[pId].Full() {
						for _, peer := range peers {
							peer.ScatterHandler.SendScatter(
								taskID,
								ctx.Step,
								int32(ctx.SendBuffers[pId].ObjCount()),
								false,
								0,
								ctx.SendBuffers[pId].PayLoad())
						}
						ctx.SendBuffers[pId].Reset()
						vOffSet = serialize.SUint32(j + 1)
						err = ctx.SendBuffers[pId].Marshal(&vOffSet)
						if err != nil {
							logrus.Errorf("sendbuffer marshal error:%v", err)
							cb.SetStatusError(taskID, fmt.Sprintf("sendbuffer marshal error:%v", err))
						}
					}
				}
			}

			// Flush whatever remains in the buffer (possibly empty).
			for _, peer := range peers {
				peer.ScatterHandler.SendScatter(
					taskID,
					ctx.Step,
					int32(ctx.SendBuffers[pId].ObjCount()),
					false,
					0,
					ctx.SendBuffers[pId].PayLoad())
			}
			ctx.SendBuffers[pId].Reset()

			// TODO: wait for all compute done
			if len(peers) > 0 {
				for sIdx, cnt := range computeTask.ComputeWorker.Scatters() {
					// Same sampling of CheckAction as above.
					if sIdx%10000 == 0 && !cb.CheckAction(computeTask) {
						return
					}
					// Partition size local to this goroutine. The original
					// code reassigned the shared outer partCnt here, racing
					// with the bIdx/eIdx reads of partCnt in sibling
					// goroutines; the per-goroutine value is identical, so
					// keeping it local fixes the race without changing
					// race-free behavior.
					scatterPart := cnt/parallel + 1
					start := scatterPart * pId
					end := cnt / parallel * (pId + 1)
					// NOTE(review): start is derived from cnt/parallel+1 but
					// end from cnt/parallel, so for middle partitions start
					// can exceed end and some indices belong to no partition;
					// confirm the intended split against WaitDone/PartStart
					// semantics before changing it.
					if pId+1 == parallel {
						end = computeTask.ComputeWorker.Context().PartStart(cnt) +
							computeTask.ComputeWorker.Context().PartCount(cnt)
					}

					vOffSet = serialize.SUint32(start)
					cb.WaitDone(start, end, pId, taskID, vOffSet, sIdx, peers, ctx, computeTask)

					for _, peer := range peers {
						peer.ScatterHandler.SendScatter(
							taskID,
							ctx.Step,
							int32(ctx.SendBuffers[pId].ObjCount()),
							false,
							int32(sIdx),
							ctx.SendBuffers[pId].PayLoad())
					}
					ctx.SendBuffers[pId].Reset()
				}
			}
		}(i)
	}
	wg.Wait()
	// All local compute done: notify every peer, then this worker itself.
	for _, peer := range peers {
		go peer.StepEndHandler.SendStepEnd(taskID)
	}

	// send to self
	go cb.StepEnd(taskID, ServiceWorker.WorkerName)
}