diff --git a/handler/unary_interceptor.go b/handler/unary_interceptor.go index 4db744ad..7fd22c5b 100644 --- a/handler/unary_interceptor.go +++ b/handler/unary_interceptor.go @@ -42,8 +42,11 @@ func (interceptor *paymentValidationUnaryInterceptor) unaryIntercept(ctx context ctx = context.WithValue(ctx, "method", info.FullMethod) + lastSlash := strings.LastIndex(info.FullMethod, "/") + methodName := info.FullMethod[lastSlash+1:] + // pass non training requests and free requests - if !strings.Contains(info.FullMethod, "validate_model") && !strings.Contains(info.FullMethod, "train_model") { + if methodName != "validate_model" && methodName != "train_model" { resp, e := handler(ctx, req) if e != nil { zap.L().Warn("gRPC handler returned error", zap.Error(e)) diff --git a/training/service.go b/training/service.go index edab529f..6e5dc7c1 100644 --- a/training/service.go +++ b/training/service.go @@ -868,12 +868,7 @@ func (ds *DaemonService) GetAllModels(ctx context.Context, request *AllModelsReq if request.IsPublic == nil || !*request.IsPublic { - userModelKey := &ModelUserKey{ - OrganizationId: config.GetString(config.OrganizationId), - ServiceId: config.GetString(config.ServiceId), - GroupId: ds.organizationMetaData.GetGroupIdString(), - UserAddress: request.Authorization.SignerAddress, - } + userModelKey := ds.userStorage.buildModelUserKey(request.Authorization.SignerAddress) if data, ok, err := ds.userStorage.Get(userModelKey); data != nil && ok && err == nil { modelKey := &ModelKey{ diff --git a/training/service_mock/test_provider_service.go b/training/service_mock/test_provider_service.go deleted file mode 100644 index 4dbfc6d7..00000000 --- a/training/service_mock/test_provider_service.go +++ /dev/null @@ -1,120 +0,0 @@ -package service_mock - -import ( - "context" - "fmt" - "log" - "net" - - "github.com/singnet/snet-daemon/v5/training" - "go.uber.org/zap" - "google.golang.org/grpc" -) - -type model struct { - modelId string - name string - desc string - grpcMethodName string - grpcServiceName string - addressList []string - isPublic bool - serviceId string - groupId string - status training.Status -} - -func StartTestService(address string) *grpc.Server { - lis, err := net.Listen("tcp", address) - if err != nil { - log.Fatalf("failed to listen: %v", err) - } - - grpcServer := grpc.NewServer() - var trainingServer TrainServer - training.RegisterModelServer(grpcServer, &trainingServer) - - trainingServer.curModelId = 0 - - go func() { - zap.L().Info("Starting test service", zap.String("address", address)) - if err := grpcServer.Serve(lis); err != nil { - zap.L().Fatal("Error in starting grpcServer", zap.Error(err)) - } - }() - - return grpcServer -} - -type TrainServer struct { - training.UnimplementedModelServer - curModelId int - models []model -} - -func (s *TrainServer) CreateModel(ctx context.Context, newModel *training.NewModel) (*training.ModelID, error) { - modelIdStr := fmt.Sprintf("%v", s.curModelId) - createdModel := &model{ - modelId: modelIdStr, - name: newModel.Name, - desc: newModel.Description, - grpcMethodName: newModel.GrpcMethodName, - grpcServiceName: newModel.GrpcServiceName, - addressList: newModel.AddressList, - isPublic: newModel.IsPublic, - serviceId: newModel.ServiceId, - groupId: newModel.GroupId, - status: training.Status_CREATED, - } - s.models = append(s.models, *createdModel) - - s.curModelId += 1 - - return &training.ModelID{ - ModelId: fmt.Sprintf("%v", s.curModelId), - }, nil -} - -func (s *TrainServer) ValidateModelPrice(ctx context.Context, request *training.ValidateRequest) (*training.PriceInBaseUnit, error) { - return &training.PriceInBaseUnit{ - Price: 1, - }, nil -} - -func (s *TrainServer) UploadAndValidate(server training.Model_UploadAndValidateServer) error { - panic("implement me") -} - -func (s *TrainServer) ValidateModel(ctx context.Context, request *training.ValidateRequest) (*training.StatusResponse, error) { - return &training.StatusResponse{ - Status: training.Status_VALIDATING, - }, nil -} - -func (s *TrainServer) TrainModelPrice(ctx context.Context, id *training.ModelID) (*training.PriceInBaseUnit, error) { - return &training.PriceInBaseUnit{ - Price: 1, - }, nil -} - -func (s *TrainServer) TrainModel(ctx context.Context, id *training.ModelID) (*training.StatusResponse, error) { - return &training.StatusResponse{ - Status: training.Status_TRAINING, - }, nil -} - -func (s *TrainServer) DeleteModel(ctx context.Context, id *training.ModelID) (*training.StatusResponse, error) { - return &training.StatusResponse{ - Status: training.Status_DELETED, - }, nil -} - -func (s *TrainServer) GetModelStatus(ctx context.Context, id *training.ModelID) (*training.StatusResponse, error) { - return &training.StatusResponse{ - Status: training.Status_VALIDATED, - }, nil -} - -func (s *TrainServer) mustEmbedUnimplementedModelServer() { - panic("implement me") -} diff --git a/training/service_test.go b/training/service_test.go index 49f52891..9dbe3bbc 100644 --- a/training/service_test.go +++ b/training/service_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "crypto/ecdsa" - "github.com/singnet/snet-daemon/v5/training/service_mock" "math/big" "slices" "strings" @@ -81,7 +80,7 @@ func (suite *DaemonServiceSuite) SetupSuite() { // setup test poriver service once address := "localhost:5001" - suite.grpcServer = service_mock.StartTestService(address) + suite.grpcServer = startTestService(address) } func (suite *DaemonServiceSuite) SetupTest() { diff --git a/training/test_provider_service.go b/training/test_provider_service.go new file mode 100644 index 00000000..01cb3bbe --- /dev/null +++ b/training/test_provider_service.go @@ -0,0 +1,119 @@ +package training + +import ( + "context" + "fmt" + "log" + "net" + + "go.uber.org/zap" + "google.golang.org/grpc" +) + +type model struct { + modelId string + name string + desc string + grpcMethodName string + grpcServiceName string + addressList []string + isPublic bool + serviceId string + groupId string + status Status +} + +func startTestService(address string) *grpc.Server { + lis, err := net.Listen("tcp", address) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + + grpcServer := grpc.NewServer() + var trainingServer TestTrainServer + RegisterModelServer(grpcServer, &trainingServer) + + trainingServer.curModelId = 0 + + go func() { + zap.L().Info("Starting test service", zap.String("address", address)) + if err := grpcServer.Serve(lis); err != nil { + zap.L().Fatal("Error in starting grpcServer", zap.Error(err)) + } + }() + + return grpcServer +} + +type TestTrainServer struct { + UnimplementedModelServer + curModelId int + models []model +} + +func (s *TestTrainServer) CreateModel(ctx context.Context, newModel *NewModel) (*ModelID, error) { + modelIdStr := fmt.Sprintf("%v", s.curModelId) + createdModel := &model{ + modelId: modelIdStr, + name: newModel.Name, + desc: newModel.Description, + grpcMethodName: newModel.GrpcMethodName, + grpcServiceName: newModel.GrpcServiceName, + addressList: newModel.AddressList, + isPublic: newModel.IsPublic, + serviceId: newModel.ServiceId, + groupId: newModel.GroupId, + status: Status_CREATED, + } + s.models = append(s.models, *createdModel) + + s.curModelId += 1 + + return &ModelID{ + ModelId: fmt.Sprintf("%v", s.curModelId), + }, nil +} + +func (s *TestTrainServer) ValidateModelPrice(ctx context.Context, request *ValidateRequest) (*PriceInBaseUnit, error) { + return &PriceInBaseUnit{ + Price: 1, + }, nil +} + +func (s *TestTrainServer) UploadAndValidate(server Model_UploadAndValidateServer) error { + panic("implement me") +} + +func (s *TestTrainServer) ValidateModel(ctx context.Context, request *ValidateRequest) (*StatusResponse, error) { + return &StatusResponse{ + Status: Status_VALIDATING, + }, nil +} + +func (s *TestTrainServer) TrainModelPrice(ctx context.Context, id *ModelID) (*PriceInBaseUnit, error) { + return &PriceInBaseUnit{ + Price: 1, + }, nil +} + +func (s *TestTrainServer) TrainModel(ctx context.Context, id *ModelID) (*StatusResponse, error) { + return &StatusResponse{ + Status: Status_TRAINING, + }, nil +} + +func (s *TestTrainServer) DeleteModel(ctx context.Context, id *ModelID) (*StatusResponse, error) { + return &StatusResponse{ + Status: Status_DELETED, + }, nil +} + +func (s *TestTrainServer) GetModelStatus(ctx context.Context, id *ModelID) (*StatusResponse, error) { + return &StatusResponse{ + Status: Status_VALIDATED, + }, nil +} + +func (s *TestTrainServer) mustEmbedUnimplementedModelServer() { + panic("implement me") +}