Skip to content

Commit

Permalink
fix get_all_models
Browse files Browse the repository at this point in the history
  • Loading branch information
semyon-dev committed Feb 20, 2025
1 parent e203f3f commit 4cf3ae9
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 129 deletions.
5 changes: 4 additions & 1 deletion handler/unary_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ func (interceptor *paymentValidationUnaryInterceptor) unaryIntercept(ctx context

ctx = context.WithValue(ctx, "method", info.FullMethod)

lastSlash := strings.LastIndex(info.FullMethod, "/")
methodName := info.FullMethod[lastSlash+1:]

// pass non training requests and free requests
if !strings.Contains(info.FullMethod, "validate_model") && !strings.Contains(info.FullMethod, "train_model") {
if methodName != "validate_model" && methodName != "train_model" {
resp, e := handler(ctx, req)
if e != nil {
zap.L().Warn("gRPC handler returned error", zap.Error(e))
Expand Down
7 changes: 1 addition & 6 deletions training/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -868,12 +868,7 @@ func (ds *DaemonService) GetAllModels(ctx context.Context, request *AllModelsReq

if request.IsPublic == nil || !*request.IsPublic {

userModelKey := &ModelUserKey{
OrganizationId: config.GetString(config.OrganizationId),
ServiceId: config.GetString(config.ServiceId),
GroupId: ds.organizationMetaData.GetGroupIdString(),
UserAddress: request.Authorization.SignerAddress,
}
userModelKey := ds.userStorage.buildModelUserKey(request.Authorization.SignerAddress)

if data, ok, err := ds.userStorage.Get(userModelKey); data != nil && ok && err == nil {
modelKey := &ModelKey{
Expand Down
120 changes: 0 additions & 120 deletions training/service_mock/test_provider_service.go

This file was deleted.

3 changes: 1 addition & 2 deletions training/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/ecdsa"
"github.com/singnet/snet-daemon/v5/training/service_mock"
"math/big"
"slices"
"strings"
Expand Down Expand Up @@ -81,7 +80,7 @@ func (suite *DaemonServiceSuite) SetupSuite() {

// setup test poriver service once
address := "localhost:5001"
suite.grpcServer = service_mock.StartTestService(address)
suite.grpcServer = startTestService(address)
}

func (suite *DaemonServiceSuite) SetupTest() {
Expand Down
119 changes: 119 additions & 0 deletions training/test_provider_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package training

import (
"context"
"fmt"
"log"
"net"

"go.uber.org/zap"
"google.golang.org/grpc"
)

type model struct {
modelId string
name string
desc string
grpcMethodName string
grpcServiceName string
addressList []string
isPublic bool
serviceId string
groupId string
status Status
}

func startTestService(address string) *grpc.Server {
lis, err := net.Listen("tcp", address)
if err != nil {
log.Fatalf("failed to listen: %v", err)
}

grpcServer := grpc.NewServer()
var trainingServer TestTrainServer
RegisterModelServer(grpcServer, &trainingServer)

trainingServer.curModelId = 0

go func() {
zap.L().Info("Starting test service", zap.String("address", address))
if err := grpcServer.Serve(lis); err != nil {
zap.L().Fatal("Error in starting grpcServer", zap.Error(err))
}
}()

return grpcServer
}

type TestTrainServer struct {
UnimplementedModelServer
curModelId int
models []model
}

func (s *TestTrainServer) CreateModel(ctx context.Context, newModel *NewModel) (*ModelID, error) {
modelIdStr := fmt.Sprintf("%v", s.curModelId)
createdModel := &model{
modelId: modelIdStr,
name: newModel.Name,
desc: newModel.Description,
grpcMethodName: newModel.GrpcMethodName,
grpcServiceName: newModel.GrpcServiceName,
addressList: newModel.AddressList,
isPublic: newModel.IsPublic,
serviceId: newModel.ServiceId,
groupId: newModel.GroupId,
status: Status_CREATED,
}
s.models = append(s.models, *createdModel)

s.curModelId += 1

return &ModelID{
ModelId: fmt.Sprintf("%v", s.curModelId),
}, nil
}

func (s *TestTrainServer) ValidateModelPrice(ctx context.Context, request *ValidateRequest) (*PriceInBaseUnit, error) {
return &PriceInBaseUnit{
Price: 1,
}, nil
}

func (s *TestTrainServer) UploadAndValidate(server Model_UploadAndValidateServer) error {
panic("implement me")
}

func (s *TestTrainServer) ValidateModel(ctx context.Context, request *ValidateRequest) (*StatusResponse, error) {
return &StatusResponse{
Status: Status_VALIDATING,
}, nil
}

func (s *TestTrainServer) TrainModelPrice(ctx context.Context, id *ModelID) (*PriceInBaseUnit, error) {
return &PriceInBaseUnit{
Price: 1,
}, nil
}

func (s *TestTrainServer) TrainModel(ctx context.Context, id *ModelID) (*StatusResponse, error) {
return &StatusResponse{
Status: Status_TRAINING,
}, nil
}

func (s *TestTrainServer) DeleteModel(ctx context.Context, id *ModelID) (*StatusResponse, error) {
return &StatusResponse{
Status: Status_DELETED,
}, nil
}

func (s *TestTrainServer) GetModelStatus(ctx context.Context, id *ModelID) (*StatusResponse, error) {
return &StatusResponse{
Status: Status_VALIDATED,
}, nil
}

func (s *TestTrainServer) mustEmbedUnimplementedModelServer() {
panic("implement me")
}

0 comments on commit 4cf3ae9

Please sign in to comment.