Skip to content

Commit

Permalink
changed variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
arunsupe committed Jul 31, 2024
1 parent feaf94f commit ac56e05
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 61 deletions.
56 changes: 10 additions & 46 deletions modules/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,16 @@ type VectorModel interface {
GetEmbedding(token string) interface{}
}

// Word2VecModel represents a 32-bit floating point Word2Vec model
type Word2VecModel struct {
// VecModel32bit represents a 32-bit floating point Word2Vec model
type VecModel32bit struct {
Vectors map[string][]float32
Size int
}

// // LoadModel loads a 32-bit floating point Word2Vec model from a file
// func (m *Word2VecModel) LoadModel(filename string) error {
// file, err := os.Open(filename)
// if err != nil {
// return fmt.Errorf("failed to open file: %v", err)
// }
// defer file.Close()

// reader := bufio.NewReader(file)

// var vocabSize, vectorSize int
// fmt.Fscanf(reader, "%d %d\n", &vocabSize, &vectorSize)

// m.Vectors = make(map[string][]float32, vocabSize)
// m.Size = vectorSize

// for i := 0; i < vocabSize; i++ {
// word, err := reader.ReadString(' ')
// if err != nil {
// return fmt.Errorf("failed to read word: %v", err)
// }
// word = strings.TrimSpace(word)

// vector := make([]float32, vectorSize)
// for j := 0; j < vectorSize; j++ {
// err := binary.Read(reader, binary.LittleEndian, &vector[j])
// if err != nil {
// return fmt.Errorf("failed to read vector: %v", err)
// }
// }
// m.Vectors[word] = vector
// }

// return nil
// }

// LoadModel loads a 32-bit floating point Word2Vec model from a file
// Attempt to validate the header and check for unexpected data
// at the end of each record and at the end of the file
func (m *Word2VecModel) LoadModel(filename string) error {
func (m *VecModel32bit) LoadModel(filename string) error {
file, err := os.Open(filename)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
Expand Down Expand Up @@ -127,24 +91,24 @@ func (m *Word2VecModel) LoadModel(filename string) error {
}

// GetEmbedding returns the vector embedding of a token for the 32-bit model
func (m *Word2VecModel) GetEmbedding(token string) interface{} {
func (m *VecModel32bit) GetEmbedding(token string) interface{} {
vec, ok := m.Vectors[token]
if !ok {
return make([]float32, m.Size)
}
return vec
}

// QuantizedWord2VecModel represents an 8-bit integer quantized Word2Vec model
type QuantizedWord2VecModel struct {
// VecModel8bit represents an 8-bit integer quantized Word2Vec model
type VecModel8bit struct {
Vectors map[string][]int8
Min float32
Max float32
Size int
}

// LoadModel loads an 8-bit integer quantized Word2Vec model from a file
func (m *QuantizedWord2VecModel) LoadModel(filename string) error {
func (m *VecModel8bit) LoadModel(filename string) error {
file, err := os.Open(filename)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
Expand Down Expand Up @@ -187,7 +151,7 @@ func (m *QuantizedWord2VecModel) LoadModel(filename string) error {
}

// GetEmbedding returns the vector embedding of a token for the 8-bit quantized model
func (m *QuantizedWord2VecModel) GetEmbedding(token string) interface{} {
func (m *VecModel8bit) GetEmbedding(token string) interface{} {
vec, ok := m.Vectors[token]
if !ok {
return make([]int8, m.Size)
Expand Down Expand Up @@ -217,9 +181,9 @@ func LoadVectorModel(filename string) (VectorModel, error) {
var model VectorModel

if strings.HasSuffix(filename, ".bin") {
model = &Word2VecModel{}
model = &VecModel32bit{}
} else if strings.HasSuffix(filename, ".8int.bin") {
model = &QuantizedWord2VecModel{}
model = &VecModel8bit{}
} else {
return nil, fmt.Errorf("unsupported file format")
}
Expand Down
15 changes: 0 additions & 15 deletions sgrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (

func main() {
modelPath := flag.String("model_path", "", "Path to the Word2Vec model file")
// quantized := flag.Bool("q", false, "Use quantized model")
similarityThreshold := flag.Float64("threshold", 0.7, "Similarity threshold for matching")
contextBefore := flag.Int("A", 0, "Number of lines before matching line")
contextAfter := flag.Int("B", 0, "Number of lines after matching line")
Expand Down Expand Up @@ -78,20 +77,6 @@ func main() {
var w2vModel model.VectorModel
var similarityCache similarity.SimilarityCache

// if *quantized {
// w2vModel, err = model.LoadQuantizedModel(*modelPath)
// if err != nil {
// fmt.Fprintf(os.Stderr, "Error loading quantized model: %v\n", err)
// os.Exit(1)
// }
// } else {
// w2vModel, err = model.LoadWord2VecModel(*modelPath)
// if err != nil {
// fmt.Fprintf(os.Stderr, "Error loading full model: %v\n", err)
// os.Exit(1)
// }
// }

w2vModel, err = model.LoadVectorModel(*modelPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading full model: %v\n", err)
Expand Down

0 comments on commit ac56e05

Please sign in to comment.