xen/pvcalls: implement accept command
[linux-2.6-block.git] / drivers / xen / pvcalls-front.c
1 /*
2  * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation; either version 2 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  */
14
#include <linux/err.h>
#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"
28
/* req_id value stored in a bedata->rsp[] slot that currently holds no response. */
#define PVCALLS_INVALID_ID UINT_MAX
/* Page order of the data buffer allocated for each active socket. */
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
/* Number of entries in bedata->rsp[], sized from the one-page command ring. */
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
32
/*
 * Per-device frontend state, stored as the xenbus device's drvdata by
 * pvcalls_front_probe() and torn down by pvcalls_front_remove().
 */
struct pvcalls_bedata {
	/* Front end of the shared command ring (initialised in probe). */
	struct xen_pvcalls_front_ring ring;
	/* Grant reference of the command ring page; -1 until granted. */
	grant_ref_t ref;
	/* IRQ bound to the command ring event channel; -1 until bound. */
	int irq;

	/* All struct sock_mapping instances known to this device. */
	struct list_head socket_mappings;
	/* Held while a request slot is reserved, filled and pushed. */
	spinlock_t socket_lock;

	/* Woken by the event handler after responses have been copied. */
	wait_queue_head_t inflight_req;
	/*
	 * Local copy of each response, indexed by req_id. A slot whose
	 * req_id is PVCALLS_INVALID_ID is free.
	 */
	struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
/*
 * Count of callers currently between pvcalls_enter() and pvcalls_exit();
 * pvcalls_front_remove() spins until it drops to zero.
 */
static atomic_t pvcalls_refcount;
47
/*
 * Mark the start of a section that uses the frontend state: first
 * increment the refcount, then proceed.
 *
 * Wrapped in do { } while (0) so that "pvcalls_enter();" is a single
 * statement and stays correct inside an unbraced if/else.
 */
#define pvcalls_enter() do {			\
	atomic_inc(&pvcalls_refcount);		\
} while (0)

/*
 * Mark the end of the section: first complete other operations, then
 * decrement the refcount.
 */
#define pvcalls_exit() do {			\
	atomic_dec(&pvcalls_refcount);		\
} while (0)
57
58 struct sock_mapping {
59         bool active_socket;
60         struct list_head list;
61         struct socket *sock;
62         union {
63                 struct {
64                         int irq;
65                         grant_ref_t ref;
66                         struct pvcalls_data_intf *ring;
67                         struct pvcalls_data data;
68                         struct mutex in_mutex;
69                         struct mutex out_mutex;
70
71                         wait_queue_head_t inflight_conn_req;
72                 } active;
73                 struct {
74                 /* Socket status */
75 #define PVCALLS_STATUS_UNINITALIZED  0
76 #define PVCALLS_STATUS_BIND          1
77 #define PVCALLS_STATUS_LISTEN        2
78                         uint8_t status;
79                 /*
80                  * Internal state-machine flags.
81                  * Only one accept operation can be inflight for a socket.
82                  * Only one poll operation can be inflight for a given socket.
83                  */
84 #define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
85                         uint8_t flags;
86                         uint32_t inflight_req_id;
87                         struct sock_mapping *accept_map;
88                         wait_queue_head_t inflight_accept_req;
89                 } passive;
90         };
91 };
92
93 static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
94 {
95         *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
96         if (RING_FULL(&bedata->ring) ||
97             bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
98                 return -EAGAIN;
99         return 0;
100 }
101
102 static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
103 {
104         struct xenbus_device *dev = dev_id;
105         struct pvcalls_bedata *bedata;
106         struct xen_pvcalls_response *rsp;
107         uint8_t *src, *dst;
108         int req_id = 0, more = 0, done = 0;
109
110         if (dev == NULL)
111                 return IRQ_HANDLED;
112
113         pvcalls_enter();
114         bedata = dev_get_drvdata(&dev->dev);
115         if (bedata == NULL) {
116                 pvcalls_exit();
117                 return IRQ_HANDLED;
118         }
119
120 again:
121         while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
122                 rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);
123
124                 req_id = rsp->req_id;
125                 dst = (uint8_t *)&bedata->rsp[req_id] + sizeof(rsp->req_id);
126                 src = (uint8_t *)rsp + sizeof(rsp->req_id);
127                 memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
128                 /*
129                  * First copy the rest of the data, then req_id. It is
130                  * paired with the barrier when accessing bedata->rsp.
131                  */
132                 smp_wmb();
133                 bedata->rsp[req_id].req_id = rsp->req_id;
134
135                 done = 1;
136                 bedata->ring.rsp_cons++;
137         }
138
139         RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
140         if (more)
141                 goto again;
142         if (done)
143                 wake_up(&bedata->inflight_req);
144         pvcalls_exit();
145         return IRQ_HANDLED;
146 }
147
/*
 * Release the resources of a socket mapping.
 *
 * Currently an empty placeholder: nothing is freed or unlinked here, so
 * callers that rely on it (the accept error path and device removal)
 * leave the mapping allocated and on bedata->socket_mappings.
 */
static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
				   struct sock_mapping *map)
{
}
152
153 static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
154 {
155         struct sock_mapping *map = sock_map;
156
157         if (map == NULL)
158                 return IRQ_HANDLED;
159
160         wake_up_interruptible(&map->active.inflight_conn_req);
161
162         return IRQ_HANDLED;
163 }
164
165 int pvcalls_front_socket(struct socket *sock)
166 {
167         struct pvcalls_bedata *bedata;
168         struct sock_mapping *map = NULL;
169         struct xen_pvcalls_request *req;
170         int notify, req_id, ret;
171
172         /*
173          * PVCalls only supports domain AF_INET,
174          * type SOCK_STREAM and protocol 0 sockets for now.
175          *
176          * Check socket type here, AF_INET and protocol checks are done
177          * by the caller.
178          */
179         if (sock->type != SOCK_STREAM)
180                 return -EOPNOTSUPP;
181
182         pvcalls_enter();
183         if (!pvcalls_front_dev) {
184                 pvcalls_exit();
185                 return -EACCES;
186         }
187         bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
188
189         map = kzalloc(sizeof(*map), GFP_KERNEL);
190         if (map == NULL) {
191                 pvcalls_exit();
192                 return -ENOMEM;
193         }
194
195         spin_lock(&bedata->socket_lock);
196
197         ret = get_request(bedata, &req_id);
198         if (ret < 0) {
199                 kfree(map);
200                 spin_unlock(&bedata->socket_lock);
201                 pvcalls_exit();
202                 return ret;
203         }
204
205         /*
206          * sock->sk->sk_send_head is not used for ip sockets: reuse the
207          * field to store a pointer to the struct sock_mapping
208          * corresponding to the socket. This way, we can easily get the
209          * struct sock_mapping from the struct socket.
210          */
211         sock->sk->sk_send_head = (void *)map;
212         list_add_tail(&map->list, &bedata->socket_mappings);
213
214         req = RING_GET_REQUEST(&bedata->ring, req_id);
215         req->req_id = req_id;
216         req->cmd = PVCALLS_SOCKET;
217         req->u.socket.id = (uintptr_t) map;
218         req->u.socket.domain = AF_INET;
219         req->u.socket.type = SOCK_STREAM;
220         req->u.socket.protocol = IPPROTO_IP;
221
222         bedata->ring.req_prod_pvt++;
223         RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
224         spin_unlock(&bedata->socket_lock);
225         if (notify)
226                 notify_remote_via_irq(bedata->irq);
227
228         wait_event(bedata->inflight_req,
229                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
230
231         /* read req_id, then the content */
232         smp_rmb();
233         ret = bedata->rsp[req_id].ret;
234         bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
235
236         pvcalls_exit();
237         return ret;
238 }
239
240 static int create_active(struct sock_mapping *map, int *evtchn)
241 {
242         void *bytes;
243         int ret = -ENOMEM, irq = -1, i;
244
245         *evtchn = -1;
246         init_waitqueue_head(&map->active.inflight_conn_req);
247
248         map->active.ring = (struct pvcalls_data_intf *)
249                 __get_free_page(GFP_KERNEL | __GFP_ZERO);
250         if (map->active.ring == NULL)
251                 goto out_error;
252         map->active.ring->ring_order = PVCALLS_RING_ORDER;
253         bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
254                                         PVCALLS_RING_ORDER);
255         if (bytes == NULL)
256                 goto out_error;
257         for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
258                 map->active.ring->ref[i] = gnttab_grant_foreign_access(
259                         pvcalls_front_dev->otherend_id,
260                         pfn_to_gfn(virt_to_pfn(bytes) + i), 0);
261
262         map->active.ref = gnttab_grant_foreign_access(
263                 pvcalls_front_dev->otherend_id,
264                 pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);
265
266         map->active.data.in = bytes;
267         map->active.data.out = bytes +
268                 XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
269
270         ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
271         if (ret)
272                 goto out_error;
273         irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
274                                         0, "pvcalls-frontend", map);
275         if (irq < 0) {
276                 ret = irq;
277                 goto out_error;
278         }
279
280         map->active.irq = irq;
281         map->active_socket = true;
282         mutex_init(&map->active.in_mutex);
283         mutex_init(&map->active.out_mutex);
284
285         return 0;
286
287 out_error:
288         if (irq >= 0)
289                 unbind_from_irqhandler(irq, map);
290         else if (*evtchn >= 0)
291                 xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
292         kfree(map->active.data.in);
293         kfree(map->active.ring);
294         return ret;
295 }
296
297 int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
298                                 int addr_len, int flags)
299 {
300         struct pvcalls_bedata *bedata;
301         struct sock_mapping *map = NULL;
302         struct xen_pvcalls_request *req;
303         int notify, req_id, ret, evtchn;
304
305         if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
306                 return -EOPNOTSUPP;
307
308         pvcalls_enter();
309         if (!pvcalls_front_dev) {
310                 pvcalls_exit();
311                 return -ENOTCONN;
312         }
313
314         bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
315
316         map = (struct sock_mapping *)sock->sk->sk_send_head;
317         if (!map) {
318                 pvcalls_exit();
319                 return -ENOTSOCK;
320         }
321
322         spin_lock(&bedata->socket_lock);
323         ret = get_request(bedata, &req_id);
324         if (ret < 0) {
325                 spin_unlock(&bedata->socket_lock);
326                 pvcalls_exit();
327                 return ret;
328         }
329         ret = create_active(map, &evtchn);
330         if (ret < 0) {
331                 spin_unlock(&bedata->socket_lock);
332                 pvcalls_exit();
333                 return ret;
334         }
335
336         req = RING_GET_REQUEST(&bedata->ring, req_id);
337         req->req_id = req_id;
338         req->cmd = PVCALLS_CONNECT;
339         req->u.connect.id = (uintptr_t)map;
340         req->u.connect.len = addr_len;
341         req->u.connect.flags = flags;
342         req->u.connect.ref = map->active.ref;
343         req->u.connect.evtchn = evtchn;
344         memcpy(req->u.connect.addr, addr, sizeof(*addr));
345
346         map->sock = sock;
347
348         bedata->ring.req_prod_pvt++;
349         RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
350         spin_unlock(&bedata->socket_lock);
351
352         if (notify)
353                 notify_remote_via_irq(bedata->irq);
354
355         wait_event(bedata->inflight_req,
356                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
357
358         /* read req_id, then the content */
359         smp_rmb();
360         ret = bedata->rsp[req_id].ret;
361         bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
362         pvcalls_exit();
363         return ret;
364 }
365
366 int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
367 {
368         struct pvcalls_bedata *bedata;
369         struct sock_mapping *map = NULL;
370         struct xen_pvcalls_request *req;
371         int notify, req_id, ret;
372
373         if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
374                 return -EOPNOTSUPP;
375
376         pvcalls_enter();
377         if (!pvcalls_front_dev) {
378                 pvcalls_exit();
379                 return -ENOTCONN;
380         }
381         bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
382
383         map = (struct sock_mapping *) sock->sk->sk_send_head;
384         if (map == NULL) {
385                 pvcalls_exit();
386                 return -ENOTSOCK;
387         }
388
389         spin_lock(&bedata->socket_lock);
390         ret = get_request(bedata, &req_id);
391         if (ret < 0) {
392                 spin_unlock(&bedata->socket_lock);
393                 pvcalls_exit();
394                 return ret;
395         }
396         req = RING_GET_REQUEST(&bedata->ring, req_id);
397         req->req_id = req_id;
398         map->sock = sock;
399         req->cmd = PVCALLS_BIND;
400         req->u.bind.id = (uintptr_t)map;
401         memcpy(req->u.bind.addr, addr, sizeof(*addr));
402         req->u.bind.len = addr_len;
403
404         init_waitqueue_head(&map->passive.inflight_accept_req);
405
406         map->active_socket = false;
407
408         bedata->ring.req_prod_pvt++;
409         RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
410         spin_unlock(&bedata->socket_lock);
411         if (notify)
412                 notify_remote_via_irq(bedata->irq);
413
414         wait_event(bedata->inflight_req,
415                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
416
417         /* read req_id, then the content */
418         smp_rmb();
419         ret = bedata->rsp[req_id].ret;
420         bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
421
422         map->passive.status = PVCALLS_STATUS_BIND;
423         pvcalls_exit();
424         return 0;
425 }
426
427 int pvcalls_front_listen(struct socket *sock, int backlog)
428 {
429         struct pvcalls_bedata *bedata;
430         struct sock_mapping *map;
431         struct xen_pvcalls_request *req;
432         int notify, req_id, ret;
433
434         pvcalls_enter();
435         if (!pvcalls_front_dev) {
436                 pvcalls_exit();
437                 return -ENOTCONN;
438         }
439         bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
440
441         map = (struct sock_mapping *) sock->sk->sk_send_head;
442         if (!map) {
443                 pvcalls_exit();
444                 return -ENOTSOCK;
445         }
446
447         if (map->passive.status != PVCALLS_STATUS_BIND) {
448                 pvcalls_exit();
449                 return -EOPNOTSUPP;
450         }
451
452         spin_lock(&bedata->socket_lock);
453         ret = get_request(bedata, &req_id);
454         if (ret < 0) {
455                 spin_unlock(&bedata->socket_lock);
456                 pvcalls_exit();
457                 return ret;
458         }
459         req = RING_GET_REQUEST(&bedata->ring, req_id);
460         req->req_id = req_id;
461         req->cmd = PVCALLS_LISTEN;
462         req->u.listen.id = (uintptr_t) map;
463         req->u.listen.backlog = backlog;
464
465         bedata->ring.req_prod_pvt++;
466         RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
467         spin_unlock(&bedata->socket_lock);
468         if (notify)
469                 notify_remote_via_irq(bedata->irq);
470
471         wait_event(bedata->inflight_req,
472                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
473
474         /* read req_id, then the content */
475         smp_rmb();
476         ret = bedata->rsp[req_id].ret;
477         bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
478
479         map->passive.status = PVCALLS_STATUS_LISTEN;
480         pvcalls_exit();
481         return ret;
482 }
483
484 int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
485 {
486         struct pvcalls_bedata *bedata;
487         struct sock_mapping *map;
488         struct sock_mapping *map2 = NULL;
489         struct xen_pvcalls_request *req;
490         int notify, req_id, ret, evtchn, nonblock;
491
492         pvcalls_enter();
493         if (!pvcalls_front_dev) {
494                 pvcalls_exit();
495                 return -ENOTCONN;
496         }
497         bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
498
499         map = (struct sock_mapping *) sock->sk->sk_send_head;
500         if (!map) {
501                 pvcalls_exit();
502                 return -ENOTSOCK;
503         }
504
505         if (map->passive.status != PVCALLS_STATUS_LISTEN) {
506                 pvcalls_exit();
507                 return -EINVAL;
508         }
509
510         nonblock = flags & SOCK_NONBLOCK;
511         /*
512          * Backend only supports 1 inflight accept request, will return
513          * errors for the others
514          */
515         if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
516                              (void *)&map->passive.flags)) {
517                 req_id = READ_ONCE(map->passive.inflight_req_id);
518                 if (req_id != PVCALLS_INVALID_ID &&
519                     READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
520                         map2 = map->passive.accept_map;
521                         goto received;
522                 }
523                 if (nonblock) {
524                         pvcalls_exit();
525                         return -EAGAIN;
526                 }
527                 if (wait_event_interruptible(map->passive.inflight_accept_req,
528                         !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
529                                           (void *)&map->passive.flags))) {
530                         pvcalls_exit();
531                         return -EINTR;
532                 }
533         }
534
535         spin_lock(&bedata->socket_lock);
536         ret = get_request(bedata, &req_id);
537         if (ret < 0) {
538                 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
539                           (void *)&map->passive.flags);
540                 spin_unlock(&bedata->socket_lock);
541                 pvcalls_exit();
542                 return ret;
543         }
544         map2 = kzalloc(sizeof(*map2), GFP_KERNEL);
545         if (map2 == NULL) {
546                 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
547                           (void *)&map->passive.flags);
548                 spin_unlock(&bedata->socket_lock);
549                 pvcalls_exit();
550                 return -ENOMEM;
551         }
552         ret = create_active(map2, &evtchn);
553         if (ret < 0) {
554                 kfree(map2);
555                 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
556                           (void *)&map->passive.flags);
557                 spin_unlock(&bedata->socket_lock);
558                 pvcalls_exit();
559                 return ret;
560         }
561         list_add_tail(&map2->list, &bedata->socket_mappings);
562
563         req = RING_GET_REQUEST(&bedata->ring, req_id);
564         req->req_id = req_id;
565         req->cmd = PVCALLS_ACCEPT;
566         req->u.accept.id = (uintptr_t) map;
567         req->u.accept.ref = map2->active.ref;
568         req->u.accept.id_new = (uintptr_t) map2;
569         req->u.accept.evtchn = evtchn;
570         map->passive.accept_map = map2;
571
572         bedata->ring.req_prod_pvt++;
573         RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
574         spin_unlock(&bedata->socket_lock);
575         if (notify)
576                 notify_remote_via_irq(bedata->irq);
577         /* We could check if we have received a response before returning. */
578         if (nonblock) {
579                 WRITE_ONCE(map->passive.inflight_req_id, req_id);
580                 pvcalls_exit();
581                 return -EAGAIN;
582         }
583
584         if (wait_event_interruptible(bedata->inflight_req,
585                 READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
586                 pvcalls_exit();
587                 return -EINTR;
588         }
589         /* read req_id, then the content */
590         smp_rmb();
591
592 received:
593         map2->sock = newsock;
594         newsock->sk = kzalloc(sizeof(*newsock->sk), GFP_KERNEL);
595         if (!newsock->sk) {
596                 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
597                 map->passive.inflight_req_id = PVCALLS_INVALID_ID;
598                 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
599                           (void *)&map->passive.flags);
600                 pvcalls_front_free_map(bedata, map2);
601                 pvcalls_exit();
602                 return -ENOMEM;
603         }
604         newsock->sk->sk_send_head = (void *)map2;
605
606         ret = bedata->rsp[req_id].ret;
607         bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
608         map->passive.inflight_req_id = PVCALLS_INVALID_ID;
609
610         clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
611         wake_up(&map->passive.inflight_accept_req);
612
613         pvcalls_exit();
614         return ret;
615 }
616
/* Xenbus device types handled by this driver; the empty entry ends the table. */
static const struct xenbus_device_id pvcalls_front_ids[] = {
	{ "pvcalls" },
	{ "" }
};
621
622 static int pvcalls_front_remove(struct xenbus_device *dev)
623 {
624         struct pvcalls_bedata *bedata;
625         struct sock_mapping *map = NULL, *n;
626
627         bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
628         dev_set_drvdata(&dev->dev, NULL);
629         pvcalls_front_dev = NULL;
630         if (bedata->irq >= 0)
631                 unbind_from_irqhandler(bedata->irq, dev);
632
633         list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
634                 map->sock->sk->sk_send_head = NULL;
635                 if (map->active_socket) {
636                         map->active.ring->in_error = -EBADF;
637                         wake_up_interruptible(&map->active.inflight_conn_req);
638                 }
639         }
640
641         smp_mb();
642         while (atomic_read(&pvcalls_refcount) > 0)
643                 cpu_relax();
644         list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
645                 if (map->active_socket) {
646                         /* No need to lock, refcount is 0 */
647                         pvcalls_front_free_map(bedata, map);
648                 } else {
649                         list_del(&map->list);
650                         kfree(map);
651                 }
652         }
653         if (bedata->ref >= 0)
654                 gnttab_end_foreign_access(bedata->ref, 0, 0);
655         kfree(bedata->ring.sring);
656         kfree(bedata);
657         xenbus_switch_state(dev, XenbusStateClosed);
658         return 0;
659 }
660
661 static int pvcalls_front_probe(struct xenbus_device *dev,
662                           const struct xenbus_device_id *id)
663 {
664         int ret = -ENOMEM, evtchn, i;
665         unsigned int max_page_order, function_calls, len;
666         char *versions;
667         grant_ref_t gref_head = 0;
668         struct xenbus_transaction xbt;
669         struct pvcalls_bedata *bedata = NULL;
670         struct xen_pvcalls_sring *sring;
671
672         if (pvcalls_front_dev != NULL) {
673                 dev_err(&dev->dev, "only one PV Calls connection supported\n");
674                 return -EINVAL;
675         }
676
677         versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
678         if (!len)
679                 return -EINVAL;
680         if (strcmp(versions, "1")) {
681                 kfree(versions);
682                 return -EINVAL;
683         }
684         kfree(versions);
685         max_page_order = xenbus_read_unsigned(dev->otherend,
686                                               "max-page-order", 0);
687         if (max_page_order < PVCALLS_RING_ORDER)
688                 return -ENODEV;
689         function_calls = xenbus_read_unsigned(dev->otherend,
690                                               "function-calls", 0);
691         /* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
692         if (function_calls != 1)
693                 return -ENODEV;
694         pr_info("%s max-page-order is %u\n", __func__, max_page_order);
695
696         bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
697         if (!bedata)
698                 return -ENOMEM;
699
700         dev_set_drvdata(&dev->dev, bedata);
701         pvcalls_front_dev = dev;
702         init_waitqueue_head(&bedata->inflight_req);
703         INIT_LIST_HEAD(&bedata->socket_mappings);
704         spin_lock_init(&bedata->socket_lock);
705         bedata->irq = -1;
706         bedata->ref = -1;
707
708         for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
709                 bedata->rsp[i].req_id = PVCALLS_INVALID_ID;
710
711         sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
712                                                              __GFP_ZERO);
713         if (!sring)
714                 goto error;
715         SHARED_RING_INIT(sring);
716         FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);
717
718         ret = xenbus_alloc_evtchn(dev, &evtchn);
719         if (ret)
720                 goto error;
721
722         bedata->irq = bind_evtchn_to_irqhandler(evtchn,
723                                                 pvcalls_front_event_handler,
724                                                 0, "pvcalls-frontend", dev);
725         if (bedata->irq < 0) {
726                 ret = bedata->irq;
727                 goto error;
728         }
729
730         ret = gnttab_alloc_grant_references(1, &gref_head);
731         if (ret < 0)
732                 goto error;
733         bedata->ref = gnttab_claim_grant_reference(&gref_head);
734         if (bedata->ref < 0) {
735                 ret = bedata->ref;
736                 goto error;
737         }
738         gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
739                                         virt_to_gfn((void *)sring), 0);
740
741  again:
742         ret = xenbus_transaction_start(&xbt);
743         if (ret) {
744                 xenbus_dev_fatal(dev, ret, "starting transaction");
745                 goto error;
746         }
747         ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
748         if (ret)
749                 goto error_xenbus;
750         ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
751         if (ret)
752                 goto error_xenbus;
753         ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
754                             evtchn);
755         if (ret)
756                 goto error_xenbus;
757         ret = xenbus_transaction_end(xbt, 0);
758         if (ret) {
759                 if (ret == -EAGAIN)
760                         goto again;
761                 xenbus_dev_fatal(dev, ret, "completing transaction");
762                 goto error;
763         }
764         xenbus_switch_state(dev, XenbusStateInitialised);
765
766         return 0;
767
768  error_xenbus:
769         xenbus_transaction_end(xbt, 1);
770         xenbus_dev_fatal(dev, ret, "writing xenstore");
771  error:
772         pvcalls_front_remove(dev);
773         return ret;
774 }
775
776 static void pvcalls_front_changed(struct xenbus_device *dev,
777                             enum xenbus_state backend_state)
778 {
779         switch (backend_state) {
780         case XenbusStateReconfiguring:
781         case XenbusStateReconfigured:
782         case XenbusStateInitialising:
783         case XenbusStateInitialised:
784         case XenbusStateUnknown:
785                 break;
786
787         case XenbusStateInitWait:
788                 break;
789
790         case XenbusStateConnected:
791                 xenbus_switch_state(dev, XenbusStateConnected);
792                 break;
793
794         case XenbusStateClosed:
795                 if (dev->state == XenbusStateClosed)
796                         break;
797                 /* Missed the backend's CLOSING state -- fallthrough */
798         case XenbusStateClosing:
799                 xenbus_frontend_closed(dev);
800                 break;
801         }
802 }
803
/* Xenbus driver descriptor registered by pvcalls_frontend_init(). */
static struct xenbus_driver pvcalls_front_driver = {
	.ids = pvcalls_front_ids,
	.probe = pvcalls_front_probe,
	.remove = pvcalls_front_remove,
	/* Called when the backend's xenbus state changes. */
	.otherend_changed = pvcalls_front_changed,
};
810
811 static int __init pvcalls_frontend_init(void)
812 {
813         if (!xen_domain())
814                 return -ENODEV;
815
816         pr_info("Initialising Xen pvcalls frontend driver\n");
817
818         return xenbus_register_frontend(&pvcalls_front_driver);
819 }
820
821 module_init(pvcalls_frontend_init);