1 module nvimhost.client;
2 
/// Nvim handle aliases. Buffers, windows and tabpages are all carried as a
/// plain `int` on this side of the RPC connection.
alias Buffer = int, Window = int, Tabpage = int;

/// Integer stand-in for fields that may arrive as msgpack nil.
alias NullInt = int;
7 
/// Single-byte msgpack encoding of nil.
enum int MsgPackNullValue = 0xC0;
9 
/// Kinds of Nvim handle objects. The members keep their implicit values 0, 1 and 2.
enum NeoExtType { Buffer, Window, Tabpage }
15 
/// Tag stored in the `kind` field of every RPC message (requests 0,
/// responses 1, notifications 2).
enum MsgKind { request, response, notify }
21 
/**
Envelope for a synchronous request.

NOTE: keep the field order. This struct is serialized with msgpack's
`pack()`, so reordering the fields would change the encoded message.
*/
struct Msg(T...) {
    import std.typecons;

    int kind;       // a MsgKind value
    ulong id;       // request id
    string method;  // RPC method name
    Tuple!T args;   // call arguments
}
30 
/**
Envelope for a fire-and-forget notification. It has no id field.

NOTE: keep the field order. The struct is serialized with msgpack's `pack()`.
*/
struct MsgAsync(T...) {
    import std.typecons;

    int kind;       // a MsgKind value
    string method;  // RPC method name
    Tuple!T args;   // call arguments
}
38 
/**
Reply envelope that a synchronous request's answer is unpacked into.

NOTE: keep the field order. It mirrors the positional layout being unpacked.
*/
struct MsgResponse(T) {
    int kind;         // a MsgKind value
    ulong id;         // message id of the reply
    NullInt nullInt;  // slot that is nil in the raw reply (nil bytes are
                      // rewritten before unpacking) — presumably the error field
    T ret;            // returned value
}
45 
/// Decoded incoming message, with its arguments held as `Variant`s.
struct MsgVariant {
    import std.typecons;
    import std.variant;

    int kind;        // a MsgKind value (or the raw type tag)
    ulong id;        // message id; stays 0 when none was decoded
    string method;   // method name, if any
    Variant[] args;  // decoded arguments
}
55 
/// Kind of handler that a decoded Nvim method name refers to.
enum MethodType { NvimFunc }

/// Result of decoding a "plugin:type:name" method string.
struct MethodInfo {
    MethodType type;  // handler kind
    string name;      // trailing name component
}
62 
struct NvimClient {
    import std.socket;
    import std.process : environment;
    import std.typecons;
    import core.time;
    import core.thread;
    import std.concurrency : thisTid, Tid, send, receiveOnly;
    import std.experimental.logger;
    import vibe.core.net;
    import eventcore.driver : IOMode;
    import nvimhost.util : genSrcAddr;

public:
    /// Nvim address read from NVIM_LISTEN_ADDRESS by connect().
    string nvimAddr;
    /// When true, the trace/log calls of this client produce output.
    bool logEnabled;

private:
    // Client main async connection
    TCPConnection conn;
    /// Destination (Nvim) address. Its family is UNIX only while connected.
    NetworkAddress netAddr;
    /// Path of the local Unix socket bound for `conn`; removed by close().
    string srcAddr;
    /// Size of the scratch buffer used for each socket read.
    enum bufferSize = 4096;
    /// Read buffer. As a D `static` it is thread-local and shared by every
    /// NvimClient in the same thread.
    static ubyte[bufferSize] buffer;

    /// Msgpack ID counter only used for synchronous messages.
    ulong msgId = 2;

    /**
    Replace every msgpack nil byte (0xc0) in `array` with `newByte`, in place,
    so that a reply can be unpacked into plain integer fields.

    NOTE(review): this works on raw bytes. A 0xc0 that is part of a string or
    of a multi-byte integer (e.g. uint8 192 encoded as `cc c0`) is rewritten
    too, which corrupts that value.
    */
    void replaceNulls(ubyte[] array, ubyte newByte) {
        foreach (ref item; array) {
            if (item == MsgPackNullValue) {
                item = newByte;
            }
        }
    }

    unittest {
        auto n = NvimClient();
        ubyte[] arr = [0, 1, 2, MsgPackNullValue, 3, MsgPackNullValue];
        n.replaceNulls(arr, 0);
        assert(arr == [0, 1, 2, 0, 3, 0]);

        // the requested replacement byte is used
        ubyte[] other = [MsgPackNullValue, 5];
        n.replaceNulls(other, 7);
        assert(other == [7, 5]);
    }

    /**
    Interpret `arr` as a big-endian unsigned number and return it as an int.
    Only the last four bytes affect the result. Earlier bytes are shifted out
    of the 32-bit value. An empty array yields 0.
    */
    int uBytesToInt(const(ubyte)[] arr) const {
        int value = 0;
        foreach (b; arr) {
            value = (value << 8) | b;
        }
        return value;
    }

    unittest {
        auto n = NvimClient();
        assert(n.uBytesToInt([1, 0]) == 256);
        assert(n.uBytesToInt([1, 1, 0]) == 65792);
        assert(n.uBytesToInt([1]) == 1);
        assert(n.uBytesToInt([1, 2, 3, 4]) == 0x01020304);
        assert(n.uBytesToInt([]) == 0);
    }

    /**
    Make a generic RPC call to Nvim. If the struct is a MsgAsync the call is
    fire-and-forget: the message is sent, no reply is read, and 0 is returned.
    If it is a Msg (sync), the reply is read and unpacked as MsgResponse!Ret.

    For Ret == ExtValue[] an int[] of decoded handles is returned (empty for an
    empty reply). For Ret == ExtValue a single int handle is returned.
    Otherwise the unpacked `ret` field is returned.

    Connects first if no connection is open.
    */
    auto callRPC(Ret = int, Struct)(Struct s) {
        import msgpack;

        // connect if not connected yet
        if (netAddr.family != AddressFamily.UNIX) {
            this.connect();
        }

        auto packed = pack(s);
        // IOMode.all: a single IOMode.once write may return after sending only
        // part of a large payload, leaving Nvim with a truncated message.
        conn.write(packed, IOMode.all);

        static if (is(Struct == MsgAsync!Args, Args...)) {
            tracef(logEnabled, "Sent async request data:\n%(%x %)", packed);
            return 0;
        } else {
            tracef(logEnabled, "Sending request %d: method: %s \n%(%x %)", s.id, s.method, packed);

            // Decode the integer handle carried by an ext value. A payload longer
            // than one byte has its first byte skipped (presumably a msgpack
            // integer prefix such as 0xcd — TODO confirm).
            int extToInt(ExtValue item) {
                auto payload = item.data.length > 1 ? item.data[1 .. $] : item.data;
                return uBytesToInt(payload);
            }

            // NOTE(review): reading stops at the first read shorter than the
            // buffer. A reply split across several reads, or one that is exactly
            // a multiple of bufferSize long, is not handled. Feeding a msgpack
            // StreamingUnpacker until a full message is parsed would be more robust.
            ubyte[] data;
            size_t nBytes;
            do {
                nBytes = conn.read(buffer, IOMode.once);
                tracef(logEnabled, "Received nBytes %d", nBytes);
                data ~= buffer[0 .. nBytes];
            }
            while (nBytes >= bufferSize);

            tracef(logEnabled, "Received response (pre-unpack) :\n%(%x %)", data);
            replaceNulls(data, NullInt.init);

            auto unpacked = unpack!(MsgResponse!Ret)(data);
            tracef(logEnabled, "Received response (unpacked bytes replaced) %d:\n%(%x %)", unpacked.id, data);

            static if (is(Ret == ExtValue[])) {
                // Always return an array, even for an empty reply, rather than
                // falling through with no return value.
                int[] nums;
                if (unpacked.ret.length) {
                    logf(logEnabled, "ExtType request %d:%s\n", unpacked.id, unpacked.ret);
                }
                foreach (item; unpacked.ret) {
                    nums ~= extToInt(item);
                }
                return nums;
            } else static if (is(Ret == ExtValue)) {
                return extToInt(unpacked.ret);
            } else {
                return unpacked.ret;
            }
        }
    }

public:

    /**
    Release resources: close the connection if one is open and remove the
    local source socket file. Calling it more than once is harmless. After
    close() the next callRPC reconnects.
    */
    void close() {
        import std.file : exists, remove;

        if (netAddr.family == AddressFamily.UNIX) {
            conn.close();
            // Forget the address so that a second close() (e.g. from the
            // destructor) does nothing and callRPC knows to reconnect.
            netAddr = NetworkAddress.init;
        }
        if (srcAddr.length && exists(srcAddr)) {
            remove(srcAddr);
        }
        srcAddr = null;
    }

    /**
    Enable logging.

    Setting NVIMHOST_LOG turns logging on. Setting NVIMHOST_LOG_FILE turns it
    on as well and redirects the shared logger to that file.
    */
    void enableLog() {
        if (environment.get("NVIMHOST_LOG")) {
            logEnabled = true;
        }
        // if this env var is not defined it'll log to stderr
        string logFile = environment.get("NVIMHOST_LOG_FILE");
        if (logFile) {
            logEnabled = true;
            sharedLog = new FileLogger(logFile);
        }
    }

    /// Releases resources via close().
    /// NOTE(review): NvimClient is copyable and every copy runs this destructor,
    /// so destroying a copy closes the connection shared with the original.
    /// Consider `@disable this(this);` if copies are never intended.
    ~this() {
        close();
    }

    /**
    Decode Nvim request/notification method names, which have this format:

    pluginName:function:functionName
    pluginName:command:commandName

    Only the `function` type is supported so far. Any other type, or a name
    that does not have exactly three `:`-separated parts, throws an Exception.
    */
    MethodInfo decodeMethod(string methodName) {
        import std.array;
        auto res = methodName.split(":");
        if (res.length != 3) {
            throw new Exception("The methodName is supposed to match this regex .+:function|command:.+");
        }
        switch(res[1]) {
            case "function":
                return MethodInfo(MethodType.NvimFunc, res[2]);
            default:
                throw new Exception("Unsupported type received: " ~ res[1]);
        }
    }

    unittest {
        auto c = NvimClient();
        auto res = c.decodeMethod("pluginName:function:SomeFunction");
        assert(res.type == MethodType.NvimFunc && res.name == "SomeFunction");
    }

    /**
    Decode a serialized incoming Nvim message into a MsgVariant.

    Top-level elements are handled as follows:
    - The first unsigned integer is the message kind.
    - The second unsigned integer is the message id. Later unsigned values are
      ignored, so they can no longer overwrite the id.
    - A raw (string) element becomes `method`.
    - Array elements are flattened into `args`.
    - Nils are skipped.

    Asserts when `arr` is shorter than 3 bytes or when no kind can be parsed.
    */
    auto inspectMsgKind(ubyte[] arr) {
        import msgpack;
        import std.variant;

        if (arr.length < 3) {
            assert(false, "Truncated message received.");
        }

        auto unpacker = StreamingUnpacker(cast(ubyte[]) null);
        unpacker.feed(arr);
        unpacker.execute();

        int msgType = -1;
        int id;
        bool idParsed;
        string method;
        Variant[] varArgs;

        /**
        Recursively unpacks array lists into `varArr`, flattening nested lists.
        Nvim wraps funcs and cmd args via RPC in nested lists issue #1929
        */
        void unpackArray(ref Variant[] varArr, Value[] values) {
            Variant arg;
            foreach (item; values) {
                switch(item.type) {
                    // chances are for plugins cast(int) will be enough for most cases
                    case Value.Type.unsigned:
                        arg = cast(int) item.via.uinteger;
                        varArr ~= arg;
                        break;
                    case Value.Type.signed:
                        arg = cast(int) item.via.integer;
                        varArr ~= arg;
                        break;
                    case Value.Type.boolean:
                        arg = item.via.boolean;
                        varArr ~= arg;
                        break;
                    case Value.Type.raw:
                        arg = cast(string) item.via.raw;
                        varArr ~= arg;
                        break;
                    case Value.Type.floating:
                        arg = item.via.floating;
                        varArr ~= arg;
                        break;
                    case Value.Type.array:
                        tracef(logEnabled, "nested array");
                        unpackArray(varArr, item.via.array);
                        break;
                    default:
                        errorf("Nested type %s is not supported yet", item.type);
                        break;
                }
            }
        }

        foreach (unpacked; unpacker.purge()) {
            if (unpacked.type == Value.Type.unsigned) {
                immutable num = cast(int) unpacked.via.uinteger;
                if (msgType == -1) {
                    msgType = num;
                } else if (!idParsed) {
                    id = num;
                    idParsed = true;
                } else {
                    // e.g. an integer result in a response: it must not clobber the id
                    tracef(logEnabled, "Ignoring extra top-level unsigned value %d", num);
                }
            } else if (unpacked.type == Value.Type.raw) {
                method = cast(string) unpacked.via.raw;
            } else if (unpacked.type == Value.Type.array) {
                unpackArray(varArgs, unpacked.via.array);
            } else if (unpacked.type == Value.Type.nil) {
                // incoming nulls don't matter.
                continue;
            } else {
                errorf("Type %s is not supported as a response param yet", unpacked.type);
            }
        }

        assert(msgType >= 0, "Couldn't parse message type.");
        assert(id >= 0, "Couldn't parse message id.");

        return MsgVariant(msgType, id, method, varArgs);
    }

    unittest {
        auto n = NvimClient();
        auto s = n.inspectMsgKind([0x94, 0x01, 0x01, MsgPackNullValue]);
        assert(s.id == 0x01 && s.kind == MsgKind.response);
        s = n.inspectMsgKind([0x94, 0x02, MsgPackNullValue]);
        assert(s.kind == MsgKind.notify);
        s = n.inspectMsgKind([0x94, 0x00, 0x02, MsgPackNullValue]);
        assert(s.id == 0x02 && s.kind == MsgKind.request);
        // 0x66 is the letter f
        s = n.inspectMsgKind([0x94, 0x00, 0x03, 0xa2, 0x66, 0x66, 0x90]);
        assert(s.id == 0x03 && s.kind == MsgKind.request && s.method == "ff" && s.args.length == 0);
        s = n.inspectMsgKind([0x94, 0x00, 0x03, 0xa2, 0x66, 0x66, 0x92, 0x1, 0xa1, 0x66]);
        assert(s.id == 0x03 && s.kind == MsgKind.request && s.method == "ff" && s.args[0] == 1 && s.args[1] == "f");
        // an integer result after the id must not overwrite the id
        s = n.inspectMsgKind([0x94, 0x01, 0x05, MsgPackNullValue, 0x2a]);
        assert(s.id == 0x05 && s.kind == MsgKind.response);
    }

    /**
    Asynchronous call: it returns immediately after serializing the data over
    RPC. No reply is read.
    */
    auto callAsync(Ret, T...)(string cmd, T args) if (is(Ret == void)) {
        auto myT = tuple(args);
        auto msgAsync = MsgAsync!(myT.Types)(MsgKind.notify, cmd, myT);
        callRPC!(Ret)(msgAsync);
    }

    /**
    Synchronous call. It sends a request with a fresh id and returns the
    decoded reply. With Ret == void the reply is still read, then discarded.
    */
    auto call(Ret = int, T...)(string cmd, T args) {
        auto myT = tuple(args);
        auto msg = Msg!(myT.Types)(MsgKind.request, ++msgId, cmd, myT);

        static if (is(Ret == void)) {
            callRPC!(int)(msg);
        } else {
            return callRPC!(Ret)(msg);
        }
    }

    /**
    Open an async TCP connection handler to Nvim using UnixAddress.

    Any connection that is already open (and its source socket file) is
    released first. The client only counts as connected once connectTCP
    succeeds, so a failed attempt is retried by the next callRPC.

    If NVIM_LISTEN_ADDRESS environment variable is not set throws
    NvimListenAddressException.
    */
    void connect() {
        // Avoid leaking a previous connection or a stale source socket file.
        close();

        this.nvimAddr = environment.get("NVIM_LISTEN_ADDRESS", "");
        if (nvimAddr == "") {
            throw new NvimListenAddressException("Couldn't get NVIM_LISTEN_ADDRESS, is nvim running?");
        }

        auto dstAddr = NetworkAddress(new UnixAddress(nvimAddr));
        srcAddr = genSrcAddr();
        auto netSrcAddr = NetworkAddress(new UnixAddress(srcAddr));
        conn = connectTCP(dstAddr, netSrcAddr);
        netAddr = dstAddr;
        conn.keepAlive(true);
        tracef(logEnabled, "Main thread connected to nvim");
    }
}
407 
/// Raised by NvimClient.connect when NVIM_LISTEN_ADDRESS is unset or empty.
class NvimListenAddressException : Exception
{
    /// Passes the message and the call-site location on to Exception.
    this(string msg, string file = __FILE__, size_t line = __LINE__)
    {
        super(msg, file, line);
    }
}