summaryrefslogtreecommitdiff
path: root/plugin
diff options
context:
space:
mode:
authorLoyalsoldier <[email protected]>2024-08-11 11:59:21 +0800
committerLoyalsoldier <[email protected]>2024-08-11 11:59:21 +0800
commitc2db35721738e91189cfb8a26fb3e97bcceb000d (patch)
tree52f207d6ffc8bc7471f87539b8075df12bf56b71 /plugin
parent87dcc81c764c3c28fea6772df4b1352a6a3f89fe (diff)
Feat: support mihomo MRS format as input & output
Diffstat (limited to 'plugin')
-rw-r--r--plugin/mihomo/mrs_in.go337
-rw-r--r--plugin/mihomo/mrs_out.go248
2 files changed, 585 insertions, 0 deletions
diff --git a/plugin/mihomo/mrs_in.go b/plugin/mihomo/mrs_in.go
new file mode 100644
index 00000000..8b5eeefb
--- /dev/null
+++ b/plugin/mihomo/mrs_in.go
@@ -0,0 +1,337 @@
+package mihomo
+
+import (
+ "bytes"
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/netip"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+
+ "github.com/Loyalsoldier/geoip/lib"
+ "github.com/klauspost/compress/zstd"
+ "go4.org/netipx"
+)
+
+var mrsMagicBytes = [4]byte{'M', 'R', 'S', 1} // MRSv1
+
+const (
+ typeMRSIn = "mihomoMRS"
+ descMRSIn = "Convert mihomo MRS data to other formats"
+)
+
+func init() {
+ lib.RegisterInputConfigCreator(typeMRSIn, func(action lib.Action, data json.RawMessage) (lib.InputConverter, error) {
+ return newMRSIn(action, data)
+ })
+ lib.RegisterInputConverter(typeMRSIn, &mrsIn{
+ Description: descMRSIn,
+ })
+}
+
+func newMRSIn(action lib.Action, data json.RawMessage) (lib.InputConverter, error) {
+ var tmp struct {
+ Name string `json:"name"`
+ URI string `json:"uri"`
+ InputDir string `json:"inputDir"`
+ OnlyIPType lib.IPType `json:"onlyIPType"`
+ }
+
+ if len(data) > 0 {
+ if err := json.Unmarshal(data, &tmp); err != nil {
+ return nil, err
+ }
+ }
+
+ if tmp.Name == "" && tmp.URI == "" && tmp.InputDir == "" {
+ return nil, fmt.Errorf("type %s | action %s missing inputdir or name or uri", typeMRSIn, action)
+ }
+
+ if (tmp.Name != "" && tmp.URI == "") || (tmp.Name == "" && tmp.URI != "") {
+ return nil, fmt.Errorf("type %s | action %s name & uri must be specified together", typeMRSIn, action)
+ }
+
+ return &mrsIn{
+ Type: typeMRSIn,
+ Action: action,
+ Description: descMRSIn,
+ Name: tmp.Name,
+ URI: tmp.URI,
+ InputDir: tmp.InputDir,
+ OnlyIPType: tmp.OnlyIPType,
+ }, nil
+}
+
+type mrsIn struct {
+ Type string
+ Action lib.Action
+ Description string
+ Name string
+ URI string
+ InputDir string
+ OnlyIPType lib.IPType
+}
+
+func (m *mrsIn) GetType() string {
+ return m.Type
+}
+
+func (m *mrsIn) GetAction() lib.Action {
+ return m.Action
+}
+
+func (m *mrsIn) GetDescription() string {
+ return m.Description
+}
+
+func (m *mrsIn) Input(container lib.Container) (lib.Container, error) {
+ entries := make(map[string]*lib.Entry)
+ var err error
+
+ switch {
+ case m.InputDir != "":
+ err = m.walkDir(m.InputDir, entries)
+ case m.Name != "" && m.URI != "":
+ switch {
+ case strings.HasPrefix(strings.ToLower(m.URI), "http://"), strings.HasPrefix(strings.ToLower(m.URI), "https://"):
+ err = m.walkRemoteFile(m.URI, m.Name, entries)
+ default:
+ err = m.walkLocalFile(m.URI, m.Name, entries)
+ }
+ default:
+ return nil, fmt.Errorf("config missing argument inputDir or name or uri")
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ var ignoreIPType lib.IgnoreIPOption
+ switch m.OnlyIPType {
+ case lib.IPv4:
+ ignoreIPType = lib.IgnoreIPv6
+ case lib.IPv6:
+ ignoreIPType = lib.IgnoreIPv4
+ }
+
+ if len(entries) == 0 {
+ return nil, fmt.Errorf("❌ [type %s | action %s] no entry is generated", m.Type, m.Action)
+ }
+
+ for _, entry := range entries {
+ switch m.Action {
+ case lib.ActionAdd:
+ if err := container.Add(entry, ignoreIPType); err != nil {
+ return nil, err
+ }
+ case lib.ActionRemove:
+ if err := container.Remove(entry, lib.CaseRemovePrefix, ignoreIPType); err != nil {
+ return nil, err
+ }
+ default:
+ return nil, lib.ErrUnknownAction
+ }
+ }
+
+ return container, nil
+}
+
+func (m *mrsIn) walkDir(dir string, entries map[string]*lib.Entry) error {
+ err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+ if info.IsDir() {
+ return nil
+ }
+
+ if err := m.walkLocalFile(path, "", entries); err != nil {
+ return err
+ }
+
+ return nil
+ })
+
+ return err
+}
+
+func (m *mrsIn) walkLocalFile(path, name string, entries map[string]*lib.Entry) error {
+ entryName := ""
+ name = strings.TrimSpace(name)
+ if name != "" {
+ entryName = name
+ } else {
+ entryName = filepath.Base(path)
+
+ // check filename
+ if !regexp.MustCompile(`^[a-zA-Z0-9_.\-]+$`).MatchString(entryName) {
+ return fmt.Errorf("filename %s cannot be entry name, please remove special characters in it", entryName)
+ }
+
+ // remove file extension but not hidden files of which filename starts with "."
+ dotIndex := strings.LastIndex(entryName, ".")
+ if dotIndex > 0 {
+ entryName = entryName[:dotIndex]
+ }
+ }
+
+ entryName = strings.ToUpper(entryName)
+ if _, found := entries[entryName]; found {
+ return fmt.Errorf("found duplicated list %s", entryName)
+ }
+
+ file, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer file.Close()
+
+ if err := m.generateEntries(entryName, file, entries); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *mrsIn) walkRemoteFile(url, name string, entries map[string]*lib.Entry) error {
+ resp, err := http.Get(url)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 200 {
+ return fmt.Errorf("failed to get remote file %s, http status code %d", url, resp.StatusCode)
+ }
+
+ if err := m.generateEntries(name, resp.Body, entries); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *mrsIn) generateEntries(name string, reader io.Reader, entries map[string]*lib.Entry) error {
+ name = strings.ToUpper(name)
+ entry, found := entries[name]
+ if !found {
+ entry = lib.NewEntry(name)
+ }
+
+ data, err := io.ReadAll(reader)
+ if err != nil {
+ return err
+ }
+
+ err = m.parseMRS(data, entry)
+ if err != nil {
+ return err
+ }
+
+ entries[name] = entry
+ return nil
+}
+
+func (m *mrsIn) parseMRS(data []byte, entry *lib.Entry) error {
+ reader, err := zstd.NewReader(bytes.NewReader(data))
+ if err != nil {
+ return err
+ }
+ defer reader.Close()
+
+ // header
+ var header [4]byte
+ _, err = io.ReadFull(reader, header[:])
+ if err != nil {
+ return err
+ }
+ if header != mrsMagicBytes {
+ return fmt.Errorf("invalid MRS format")
+ }
+
+ // behavior
+ var behavior [1]byte
+ _, err = io.ReadFull(reader, behavior[:])
+ if err != nil {
+ return err
+ }
+ if behavior[0] != byte(1) { // RuleBehavior IPCIDR = 1
+ return fmt.Errorf("invalid MRS IPCIDR data")
+ }
+
+ // count
+ var count int64
+ err = binary.Read(reader, binary.BigEndian, &count)
+ if err != nil {
+ return err
+ }
+
+ // extra (reserved for future using)
+ var length int64
+ err = binary.Read(reader, binary.BigEndian, &length)
+ if err != nil {
+ return err
+ }
+ if length < 0 {
+ return fmt.Errorf("invalid MRS extra length")
+ }
+ if length > 0 {
+ extra := make([]byte, length)
+ _, err = io.ReadFull(reader, extra)
+ if err != nil {
+ return err
+ }
+ }
+
+ //
+ // rules
+ //
+ // version
+ version := make([]byte, 1)
+ _, err = io.ReadFull(reader, version)
+ if err != nil {
+ return err
+ }
+ if version[0] != 1 {
+ return fmt.Errorf("invalid MRS rule version")
+ }
+
+ // rule length
+ var ruleLength int64
+ err = binary.Read(reader, binary.BigEndian, &ruleLength)
+ if err != nil {
+ return err
+ }
+ if ruleLength < 1 {
+ return fmt.Errorf("invalid MRS rule length")
+ }
+
+ for i := int64(0); i < ruleLength; i++ {
+ var a16 [16]byte
+ err = binary.Read(reader, binary.BigEndian, &a16)
+ if err != nil {
+ return err
+ }
+ from := netip.AddrFrom16(a16).Unmap()
+
+ err = binary.Read(reader, binary.BigEndian, &a16)
+ if err != nil {
+ return err
+ }
+ to := netip.AddrFrom16(a16).Unmap()
+
+ iprange := netipx.IPRangeFrom(from, to)
+ for _, prefix := range iprange.Prefixes() {
+ if err := entry.AddPrefix(prefix); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/plugin/mihomo/mrs_out.go b/plugin/mihomo/mrs_out.go
new file mode 100644
index 00000000..85717b30
--- /dev/null
+++ b/plugin/mihomo/mrs_out.go
@@ -0,0 +1,248 @@
+package mihomo
+
+import (
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+
+ "github.com/Loyalsoldier/geoip/lib"
+ "github.com/klauspost/compress/zstd"
+ "go4.org/netipx"
+)
+
+const (
+ typeMRSOut = "mihomoMRS"
+ descMRSOut = "Convert data to mihomo MRS format"
+)
+
+var (
+ defaultOutputDir = filepath.Join("./", "output", "mrs")
+)
+
+func init() {
+ lib.RegisterOutputConfigCreator(typeMRSOut, func(action lib.Action, data json.RawMessage) (lib.OutputConverter, error) {
+ return newMRSOut(action, data)
+ })
+ lib.RegisterOutputConverter(typeMRSOut, &mrsOut{
+ Description: descMRSOut,
+ })
+}
+
+func newMRSOut(action lib.Action, data json.RawMessage) (lib.OutputConverter, error) {
+ var tmp struct {
+ OutputDir string `json:"outputDir"`
+ Want []string `json:"wantedList"`
+ OnlyIPType lib.IPType `json:"onlyIPType"`
+ }
+
+ if len(data) > 0 {
+ if err := json.Unmarshal(data, &tmp); err != nil {
+ return nil, err
+ }
+ }
+
+ if tmp.OutputDir == "" {
+ tmp.OutputDir = defaultOutputDir
+ }
+
+ return &mrsOut{
+ Type: typeMRSOut,
+ Action: action,
+ Description: descMRSOut,
+ OutputDir: tmp.OutputDir,
+ Want: tmp.Want,
+ OnlyIPType: tmp.OnlyIPType,
+ }, nil
+}
+
+type mrsOut struct {
+ Type string
+ Action lib.Action
+ Description string
+ OutputDir string
+ Want []string
+ OnlyIPType lib.IPType
+}
+
+func (m *mrsOut) GetType() string {
+ return m.Type
+}
+
+func (m *mrsOut) GetAction() lib.Action {
+ return m.Action
+}
+
+func (m *mrsOut) GetDescription() string {
+ return m.Description
+}
+
+func (m *mrsOut) Output(container lib.Container) error {
+ // Filter want list
+ wantList := make([]string, 0, len(m.Want))
+ for _, want := range m.Want {
+ if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
+ wantList = append(wantList, want)
+ }
+ }
+
+ switch len(wantList) {
+ case 0:
+ list := make([]string, 0, 300)
+ for entry := range container.Loop() {
+ list = append(list, entry.GetName())
+ }
+
+ // Sort the list
+ slices.Sort(list)
+
+ for _, name := range list {
+ entry, found := container.GetEntry(name)
+ if !found {
+ log.Printf("❌ entry %s not found", name)
+ continue
+ }
+ if err := m.generate(entry); err != nil {
+ return err
+ }
+ }
+
+ default:
+ // Sort the list
+ slices.Sort(wantList)
+
+ for _, name := range wantList {
+ entry, found := container.GetEntry(name)
+ if !found {
+ log.Printf("❌ entry %s not found", name)
+ continue
+ }
+
+ if err := m.generate(entry); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func (m *mrsOut) generate(entry *lib.Entry) error {
+ var ipRanges []netipx.IPRange
+ var err error
+ switch m.OnlyIPType {
+ case lib.IPv4:
+ ipRanges, err = entry.MarshalIPRange(lib.IgnoreIPv6)
+ case lib.IPv6:
+ ipRanges, err = entry.MarshalIPRange(lib.IgnoreIPv4)
+ default:
+ ipRanges, err = entry.MarshalIPRange()
+ }
+ if err != nil {
+ return err
+ }
+
+ if len(ipRanges) == 0 {
+ return fmt.Errorf("entry %s has no CIDR", entry.GetName())
+ }
+
+ filename := strings.ToLower(entry.GetName()) + ".mrs"
+ if err := m.writeFile(filename, ipRanges); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *mrsOut) writeFile(filename string, ipRanges []netipx.IPRange) error {
+ if err := os.MkdirAll(m.OutputDir, 0755); err != nil {
+ return err
+ }
+
+ f, err := os.Create(filepath.Join(m.OutputDir, filename))
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ err = m.convertToMrs(ipRanges, f)
+ if err != nil {
+ return err
+ }
+
+ log.Printf("✅ [%s] %s --> %s", m.Type, filename, m.OutputDir)
+
+ return nil
+}
+
+func (m *mrsOut) convertToMrs(ipRanges []netipx.IPRange, w io.Writer) (err error) {
+ encoder, err := zstd.NewWriter(w)
+ if err != nil {
+ return err
+ }
+ defer encoder.Close()
+
+ // header
+ _, err = encoder.Write(mrsMagicBytes[:])
+ if err != nil {
+ return err
+ }
+
+ // behavior
+ _, err = encoder.Write([]byte{1}) // RuleBehavior IPCIDR = 1
+ if err != nil {
+ return err
+ }
+
+ // count
+ count := int64(len(ipRanges))
+ err = binary.Write(encoder, binary.BigEndian, count)
+ if err != nil {
+ return err
+ }
+
+ // extra (reserved for future using)
+ var extra []byte
+ err = binary.Write(encoder, binary.BigEndian, int64(len(extra)))
+ if err != nil {
+ return err
+ }
+ _, err = encoder.Write(extra)
+ if err != nil {
+ return err
+ }
+
+ //
+ // rule
+ //
+ // version
+ _, err = encoder.Write([]byte{1})
+ if err != nil {
+ return err
+ }
+
+ // rule length
+ err = binary.Write(encoder, binary.BigEndian, int64(len(ipRanges)))
+ if err != nil {
+ return err
+ }
+
+ for _, ipRange := range ipRanges {
+ err = binary.Write(encoder, binary.BigEndian, ipRange.From().As16())
+ if err != nil {
+ return err
+ }
+
+ err = binary.Write(encoder, binary.BigEndian, ipRange.To().As16())
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}