diff options
| author | Loyalsoldier <[email protected]> | 2024-08-11 11:59:21 +0800 |
|---|---|---|
| committer | Loyalsoldier <[email protected]> | 2024-08-11 11:59:21 +0800 |
| commit | c2db35721738e91189cfb8a26fb3e97bcceb000d (patch) | |
| tree | 52f207d6ffc8bc7471f87539b8075df12bf56b71 /plugin | |
| parent | 87dcc81c764c3c28fea6772df4b1352a6a3f89fe (diff) | |
Feat: support mihomo MRS format as input & output
Diffstat (limited to 'plugin')
| -rw-r--r-- | plugin/mihomo/mrs_in.go | 337 | ||||
| -rw-r--r-- | plugin/mihomo/mrs_out.go | 248 |
2 files changed, 585 insertions, 0 deletions
diff --git a/plugin/mihomo/mrs_in.go b/plugin/mihomo/mrs_in.go new file mode 100644 index 00000000..8b5eeefb --- /dev/null +++ b/plugin/mihomo/mrs_in.go @@ -0,0 +1,337 @@ +package mihomo + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "net/netip" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/Loyalsoldier/geoip/lib" + "github.com/klauspost/compress/zstd" + "go4.org/netipx" +) + +var mrsMagicBytes = [4]byte{'M', 'R', 'S', 1} // MRSv1 + +const ( + typeMRSIn = "mihomoMRS" + descMRSIn = "Convert mihomo MRS data to other formats" +) + +func init() { + lib.RegisterInputConfigCreator(typeMRSIn, func(action lib.Action, data json.RawMessage) (lib.InputConverter, error) { + return newMRSIn(action, data) + }) + lib.RegisterInputConverter(typeMRSIn, &mrsIn{ + Description: descMRSIn, + }) +} + +func newMRSIn(action lib.Action, data json.RawMessage) (lib.InputConverter, error) { + var tmp struct { + Name string `json:"name"` + URI string `json:"uri"` + InputDir string `json:"inputDir"` + OnlyIPType lib.IPType `json:"onlyIPType"` + } + + if len(data) > 0 { + if err := json.Unmarshal(data, &tmp); err != nil { + return nil, err + } + } + + if tmp.Name == "" && tmp.URI == "" && tmp.InputDir == "" { + return nil, fmt.Errorf("type %s | action %s missing inputdir or name or uri", typeMRSIn, action) + } + + if (tmp.Name != "" && tmp.URI == "") || (tmp.Name == "" && tmp.URI != "") { + return nil, fmt.Errorf("type %s | action %s name & uri must be specified together", typeMRSIn, action) + } + + return &mrsIn{ + Type: typeMRSIn, + Action: action, + Description: descMRSIn, + Name: tmp.Name, + URI: tmp.URI, + InputDir: tmp.InputDir, + OnlyIPType: tmp.OnlyIPType, + }, nil +} + +type mrsIn struct { + Type string + Action lib.Action + Description string + Name string + URI string + InputDir string + OnlyIPType lib.IPType +} + +func (m *mrsIn) GetType() string { + return m.Type +} + +func (m *mrsIn) GetAction() lib.Action { + return m.Action +} + +func (m *mrsIn) GetDescription() string { + return m.Description +} + +func (m *mrsIn) Input(container lib.Container) (lib.Container, error) { + entries := make(map[string]*lib.Entry) + var err error + + switch { + case m.InputDir != "": + err = m.walkDir(m.InputDir, entries) + case m.Name != "" && m.URI != "": + switch { + case strings.HasPrefix(strings.ToLower(m.URI), "http://"), strings.HasPrefix(strings.ToLower(m.URI), "https://"): + err = m.walkRemoteFile(m.URI, m.Name, entries) + default: + err = m.walkLocalFile(m.URI, m.Name, entries) + } + default: + return nil, fmt.Errorf("config missing argument inputDir or name or uri") + } + + if err != nil { + return nil, err + } + + var ignoreIPType lib.IgnoreIPOption + switch m.OnlyIPType { + case lib.IPv4: + ignoreIPType = lib.IgnoreIPv6 + case lib.IPv6: + ignoreIPType = lib.IgnoreIPv4 + } + + if len(entries) == 0 { + return nil, fmt.Errorf("❌ [type %s | action %s] no entry is generated", m.Type, m.Action) + } + + for _, entry := range entries { + switch m.Action { + case lib.ActionAdd: + if err := container.Add(entry, ignoreIPType); err != nil { + return nil, err + } + case lib.ActionRemove: + if err := container.Remove(entry, lib.CaseRemovePrefix, ignoreIPType); err != nil { + return nil, err + } + default: + return nil, lib.ErrUnknownAction + } + } + + return container, nil +} + +func (m *mrsIn) walkDir(dir string, entries map[string]*lib.Entry) error { + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + + if err := m.walkLocalFile(path, "", entries); err != nil { + return err + } + + return nil + }) + + return err +} + +func (m *mrsIn) walkLocalFile(path, name string, entries map[string]*lib.Entry) error { + entryName := "" + name = strings.TrimSpace(name) + if name != "" { + entryName = name + } else { + entryName = filepath.Base(path) + + // check filename + if !regexp.MustCompile(`^[a-zA-Z0-9_.\-]+$`).MatchString(entryName) { + return fmt.Errorf("filename %s cannot be entry name, please remove special characters in it", entryName) + } + + // remove file extension but not hidden files of which filename starts with "." + dotIndex := strings.LastIndex(entryName, ".") + if dotIndex > 0 { + entryName = entryName[:dotIndex] + } + } + + entryName = strings.ToUpper(entryName) + if _, found := entries[entryName]; found { + return fmt.Errorf("found duplicated list %s", entryName) + } + + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + if err := m.generateEntries(entryName, file, entries); err != nil { + return err + } + + return nil +} + +func (m *mrsIn) walkRemoteFile(url, name string, entries map[string]*lib.Entry) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return fmt.Errorf("failed to get remote file %s, http status code %d", url, resp.StatusCode) + } + + if err := m.generateEntries(name, resp.Body, entries); err != nil { + return err + } + + return nil +} + +func (m *mrsIn) generateEntries(name string, reader io.Reader, entries map[string]*lib.Entry) error { + name = strings.ToUpper(name) + entry, found := entries[name] + if !found { + entry = lib.NewEntry(name) + } + + data, err := io.ReadAll(reader) + if err != nil { + return err + } + + err = m.parseMRS(data, entry) + if err != nil { + return err + } + + entries[name] = entry + return nil +} + +func (m *mrsIn) parseMRS(data []byte, entry *lib.Entry) error { + reader, err := zstd.NewReader(bytes.NewReader(data)) + if err != nil { + return err + } + defer reader.Close() + + // header + var header [4]byte + _, err = io.ReadFull(reader, header[:]) + if err != nil { + return err + } + if header != mrsMagicBytes { + return fmt.Errorf("invalid MRS format") + } + + // behavior + var behavior [1]byte + _, err = io.ReadFull(reader, behavior[:]) + if err != nil { + return err + } + if behavior[0] != byte(1) { // RuleBehavior IPCIDR = 1 + return fmt.Errorf("invalid MRS IPCIDR data") + } + + // count + var count int64 + err = binary.Read(reader, binary.BigEndian, &count) + if err != nil { + return err + } + + // extra (reserved for future using) + var length int64 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return err + } + if length < 0 { + return fmt.Errorf("invalid MRS extra length") + } + if length > 0 { + extra := make([]byte, length) + _, err = io.ReadFull(reader, extra) + if err != nil { + return err + } + } + + // + // rules + // + // version + version := make([]byte, 1) + _, err = io.ReadFull(reader, version) + if err != nil { + return err + } + if version[0] != 1 { + return fmt.Errorf("invalid MRS rule version") + } + + // rule length + var ruleLength int64 + err = binary.Read(reader, binary.BigEndian, &ruleLength) + if err != nil { + return err + } + if ruleLength < 1 { + return fmt.Errorf("invalid MRS rule length") + } + + for i := int64(0); i < ruleLength; i++ { + var a16 [16]byte + err = binary.Read(reader, binary.BigEndian, &a16) + if err != nil { + return err + } + from := netip.AddrFrom16(a16).Unmap() + + err = binary.Read(reader, binary.BigEndian, &a16) + if err != nil { + return err + } + to := netip.AddrFrom16(a16).Unmap() + + iprange := netipx.IPRangeFrom(from, to) + for _, prefix := range iprange.Prefixes() { + if err := entry.AddPrefix(prefix); err != nil { + return err + } + } + } + + return nil +} diff --git a/plugin/mihomo/mrs_out.go b/plugin/mihomo/mrs_out.go new file mode 100644 index 00000000..85717b30 --- /dev/null +++ b/plugin/mihomo/mrs_out.go @@ -0,0 +1,248 @@ +package mihomo + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "log" + "os" + "path/filepath" + "slices" + "strings" + + "github.com/Loyalsoldier/geoip/lib" + "github.com/klauspost/compress/zstd" + "go4.org/netipx" +) + +const ( + typeMRSOut = "mihomoMRS" + descMRSOut = "Convert data to mihomo MRS format" +) + +var ( + defaultOutputDir = filepath.Join("./", "output", "mrs") +) + +func init() { + lib.RegisterOutputConfigCreator(typeMRSOut, func(action lib.Action, data json.RawMessage) (lib.OutputConverter, error) { + return newMRSOut(action, data) + }) + lib.RegisterOutputConverter(typeMRSOut, &mrsOut{ + Description: descMRSOut, + }) +} + +func newMRSOut(action lib.Action, data json.RawMessage) (lib.OutputConverter, error) { + var tmp struct { + OutputDir string `json:"outputDir"` + Want []string `json:"wantedList"` + OnlyIPType lib.IPType `json:"onlyIPType"` + } + + if len(data) > 0 { + if err := json.Unmarshal(data, &tmp); err != nil { + return nil, err + } + } + + if tmp.OutputDir == "" { + tmp.OutputDir = defaultOutputDir + } + + return &mrsOut{ + Type: typeMRSOut, + Action: action, + Description: descMRSOut, + OutputDir: tmp.OutputDir, + Want: tmp.Want, + OnlyIPType: tmp.OnlyIPType, + }, nil +} + +type mrsOut struct { + Type string + Action lib.Action + Description string + OutputDir string + Want []string + OnlyIPType lib.IPType +} + +func (m *mrsOut) GetType() string { + return m.Type +} + +func (m *mrsOut) GetAction() lib.Action { + return m.Action +} + +func (m *mrsOut) GetDescription() string { + return m.Description +} + +func (m *mrsOut) Output(container lib.Container) error { + // Filter want list + wantList := make([]string, 0, len(m.Want)) + for _, want := range m.Want { + if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { + wantList = append(wantList, want) + } + } + + switch len(wantList) { + case 0: + list := make([]string, 0, 300) + for entry := range container.Loop() { + list = append(list, entry.GetName()) + } + + // Sort the list + slices.Sort(list) + + for _, name := range list { + entry, found := container.GetEntry(name) + if !found { + log.Printf("❌ entry %s not found", name) + continue + } + if err := m.generate(entry); err != nil { + return err + } + } + + default: + // Sort the list + slices.Sort(wantList) + + for _, name := range wantList { + entry, found := container.GetEntry(name) + if !found { + log.Printf("❌ entry %s not found", name) + continue + } + + if err := m.generate(entry); err != nil { + return err + } + } + } + + return nil +} + +func (m *mrsOut) generate(entry *lib.Entry) error { + var ipRanges []netipx.IPRange + var err error + switch m.OnlyIPType { + case lib.IPv4: + ipRanges, err = entry.MarshalIPRange(lib.IgnoreIPv6) + case lib.IPv6: + ipRanges, err = entry.MarshalIPRange(lib.IgnoreIPv4) + default: + ipRanges, err = entry.MarshalIPRange() + } + if err != nil { + return err + } + + if len(ipRanges) == 0 { + return fmt.Errorf("entry %s has no CIDR", entry.GetName()) + } + + filename := strings.ToLower(entry.GetName()) + ".mrs" + if err := m.writeFile(filename, ipRanges); err != nil { + return err + } + + return nil +} + +func (m *mrsOut) writeFile(filename string, ipRanges []netipx.IPRange) error { + if err := os.MkdirAll(m.OutputDir, 0755); err != nil { + return err + } + + f, err := os.Create(filepath.Join(m.OutputDir, filename)) + if err != nil { + return err + } + defer f.Close() + + err = m.convertToMrs(ipRanges, f) + if err != nil { + return err + } + + log.Printf("✅ [%s] %s --> %s", m.Type, filename, m.OutputDir) + + return nil +} + +func (m *mrsOut) convertToMrs(ipRanges []netipx.IPRange, w io.Writer) (err error) { + encoder, err := zstd.NewWriter(w) + if err != nil { + return err + } + defer encoder.Close() + + // header + _, err = encoder.Write(mrsMagicBytes[:]) + if err != nil { + return err + } + + // behavior + _, err = encoder.Write([]byte{1}) // RuleBehavior IPCIDR = 1 + if err != nil { + return err + } + + // count + count := int64(len(ipRanges)) + err = binary.Write(encoder, binary.BigEndian, count) + if err != nil { + return err + } + + // extra (reserved for future using) + var extra []byte + err = binary.Write(encoder, binary.BigEndian, int64(len(extra))) + if err != nil { + return err + } + _, err = encoder.Write(extra) + if err != nil { + return err + } + + // + // rule + // + // version + _, err = encoder.Write([]byte{1}) + if err != nil { + return err + } + + // rule length + err = binary.Write(encoder, binary.BigEndian, int64(len(ipRanges))) + if err != nil { + return err + } + + for _, ipRange := range ipRanges { + err = binary.Write(encoder, binary.BigEndian, ipRange.From().As16()) + if err != nil { + return err + } + + err = binary.Write(encoder, binary.BigEndian, ipRange.To().As16()) + if err != nil { + return err + } + } + + return nil +} |
